Skip to main content

survival/validation/
landmark.rs

1use crate::constants::{PARALLEL_THRESHOLD_SMALL, z_score_for_confidence};
2use crate::utilities::statistical::normal_cdf as norm_cdf;
3use pyo3::prelude::*;
4use rayon::prelude::*;
5#[derive(Debug, Clone)]
6#[pyclass]
7pub struct LandmarkResult {
8    #[pyo3(get)]
9    pub landmark_time: f64,
10    #[pyo3(get)]
11    pub n_at_risk: usize,
12    #[pyo3(get)]
13    pub n_excluded: usize,
14    #[pyo3(get)]
15    pub time: Vec<f64>,
16    #[pyo3(get)]
17    pub status: Vec<i32>,
18    #[pyo3(get)]
19    pub original_indices: Vec<usize>,
20}
21#[pymethods]
22impl LandmarkResult {
23    #[new]
24    fn new(
25        landmark_time: f64,
26        n_at_risk: usize,
27        n_excluded: usize,
28        time: Vec<f64>,
29        status: Vec<i32>,
30        original_indices: Vec<usize>,
31    ) -> Self {
32        Self {
33            landmark_time,
34            n_at_risk,
35            n_excluded,
36            time,
37            status,
38            original_indices,
39        }
40    }
41}
42pub fn compute_landmark(time: &[f64], status: &[i32], landmark_time: f64) -> LandmarkResult {
43    let n = time.len();
44    let mut new_time = Vec::new();
45    let mut new_status = Vec::new();
46    let mut original_indices = Vec::new();
47    let mut n_excluded = 0usize;
48    for i in 0..n {
49        if time[i] > landmark_time {
50            new_time.push(time[i] - landmark_time);
51            new_status.push(status[i]);
52            original_indices.push(i);
53        } else {
54            n_excluded += 1;
55        }
56    }
57    let n_at_risk = new_time.len();
58    LandmarkResult {
59        landmark_time,
60        n_at_risk,
61        n_excluded,
62        time: new_time,
63        status: new_status,
64        original_indices,
65    }
66}
67#[pyfunction]
68pub fn landmark_analysis(
69    time: Vec<f64>,
70    status: Vec<i32>,
71    landmark_time: f64,
72) -> PyResult<LandmarkResult> {
73    Ok(compute_landmark(&time, &status, landmark_time))
74}
75pub fn compute_landmarks_parallel(
76    time: &[f64],
77    status: &[i32],
78    landmark_times: &[f64],
79) -> Vec<LandmarkResult> {
80    landmark_times
81        .par_iter()
82        .map(|&lt| compute_landmark(time, status, lt))
83        .collect()
84}
85#[pyfunction]
86pub fn landmark_analysis_batch(
87    time: Vec<f64>,
88    status: Vec<i32>,
89    landmark_times: Vec<f64>,
90) -> PyResult<Vec<LandmarkResult>> {
91    Ok(compute_landmarks_parallel(&time, &status, &landmark_times))
92}
93#[derive(Debug, Clone)]
94#[pyclass]
95pub struct ConditionalSurvivalResult {
96    #[pyo3(get)]
97    pub given_time: f64,
98    #[pyo3(get)]
99    pub target_time: f64,
100    #[pyo3(get)]
101    pub conditional_survival: f64,
102    #[pyo3(get)]
103    pub ci_lower: f64,
104    #[pyo3(get)]
105    pub ci_upper: f64,
106    #[pyo3(get)]
107    pub n_at_risk: usize,
108}
109#[pymethods]
110impl ConditionalSurvivalResult {
111    #[new]
112    fn new(
113        given_time: f64,
114        target_time: f64,
115        conditional_survival: f64,
116        ci_lower: f64,
117        ci_upper: f64,
118        n_at_risk: usize,
119    ) -> Self {
120        Self {
121            given_time,
122            target_time,
123            conditional_survival,
124            ci_lower,
125            ci_upper,
126            n_at_risk,
127        }
128    }
129}
130pub fn compute_conditional_survival(
131    time: &[f64],
132    status: &[i32],
133    given_time: f64,
134    target_time: f64,
135    confidence_level: f64,
136) -> ConditionalSurvivalResult {
137    let n = time.len();
138    if n == 0 || target_time <= given_time {
139        return ConditionalSurvivalResult {
140            given_time,
141            target_time,
142            conditional_survival: 1.0,
143            ci_lower: 1.0,
144            ci_upper: 1.0,
145            n_at_risk: 0,
146        };
147    }
148    let mut indices: Vec<usize> = (0..n).collect();
149    indices.sort_by(|&a, &b| {
150        time[a]
151            .partial_cmp(&time[b])
152            .unwrap_or(std::cmp::Ordering::Equal)
153    });
154    let mut surv_given = 1.0;
155    let mut surv_target = 1.0;
156    let mut var_given = 0.0;
157    let mut var_target = 0.0;
158    let mut total_at_risk = n as f64;
159    let mut n_at_given = 0usize;
160    let mut i = 0;
161    while i < n {
162        let current_time = time[indices[i]];
163        let mut events = 0.0;
164        let mut removed = 0.0;
165        while i < n && time[indices[i]] == current_time {
166            removed += 1.0;
167            if status[indices[i]] == 1 {
168                events += 1.0;
169            }
170            i += 1;
171        }
172        if events > 0.0 && total_at_risk > 0.0 {
173            let hazard = events / total_at_risk;
174            if current_time <= given_time {
175                surv_given *= 1.0 - hazard;
176                if total_at_risk > events {
177                    var_given += events / (total_at_risk * (total_at_risk - events));
178                }
179            }
180            if current_time <= target_time {
181                surv_target *= 1.0 - hazard;
182                if total_at_risk > events {
183                    var_target += events / (total_at_risk * (total_at_risk - events));
184                }
185            }
186        }
187        if current_time <= given_time {
188            n_at_given = (total_at_risk - removed) as usize;
189        }
190        total_at_risk -= removed;
191    }
192    let conditional = if surv_given > 0.0 {
193        surv_target / surv_given
194    } else {
195        0.0
196    };
197    let z = z_score_for_confidence(confidence_level);
198    let var_conditional = if surv_given > 0.0 {
199        conditional * conditional * (var_target - var_given).abs()
200    } else {
201        0.0
202    };
203    let se = var_conditional.sqrt();
204    let ci_lower = (conditional - z * se).clamp(0.0, 1.0);
205    let ci_upper = (conditional + z * se).clamp(0.0, 1.0);
206    ConditionalSurvivalResult {
207        given_time,
208        target_time,
209        conditional_survival: conditional,
210        ci_lower,
211        ci_upper,
212        n_at_risk: n_at_given,
213    }
214}
215#[pyfunction]
216#[pyo3(signature = (time, status, given_time, target_time, confidence_level=None))]
217pub fn conditional_survival(
218    time: Vec<f64>,
219    status: Vec<i32>,
220    given_time: f64,
221    target_time: f64,
222    confidence_level: Option<f64>,
223) -> PyResult<ConditionalSurvivalResult> {
224    let conf = confidence_level.unwrap_or(0.95);
225    Ok(compute_conditional_survival(
226        &time,
227        &status,
228        given_time,
229        target_time,
230        conf,
231    ))
232}
233#[derive(Debug, Clone)]
234#[pyclass]
235pub struct HazardRatioResult {
236    #[pyo3(get)]
237    pub hazard_ratio: f64,
238    #[pyo3(get)]
239    pub ci_lower: f64,
240    #[pyo3(get)]
241    pub ci_upper: f64,
242    #[pyo3(get)]
243    pub se_log_hr: f64,
244    #[pyo3(get)]
245    pub z_statistic: f64,
246    #[pyo3(get)]
247    pub p_value: f64,
248}
249#[pymethods]
250impl HazardRatioResult {
251    #[new]
252    fn new(
253        hazard_ratio: f64,
254        ci_lower: f64,
255        ci_upper: f64,
256        se_log_hr: f64,
257        z_statistic: f64,
258        p_value: f64,
259    ) -> Self {
260        Self {
261            hazard_ratio,
262            ci_lower,
263            ci_upper,
264            se_log_hr,
265            z_statistic,
266            p_value,
267        }
268    }
269}
270pub fn compute_hazard_ratio(
271    time: &[f64],
272    status: &[i32],
273    group: &[i32],
274    confidence_level: f64,
275) -> HazardRatioResult {
276    let n = time.len();
277    if n == 0 {
278        return HazardRatioResult {
279            hazard_ratio: 1.0,
280            ci_lower: 1.0,
281            ci_upper: 1.0,
282            se_log_hr: 0.0,
283            z_statistic: 0.0,
284            p_value: 1.0,
285        };
286    }
287    let mut unique_groups: Vec<i32> = group.to_vec();
288    unique_groups.sort();
289    unique_groups.dedup();
290    if unique_groups.len() < 2 {
291        return HazardRatioResult {
292            hazard_ratio: 1.0,
293            ci_lower: 1.0,
294            ci_upper: 1.0,
295            se_log_hr: 0.0,
296            z_statistic: 0.0,
297            p_value: 1.0,
298        };
299    }
300    let g1 = unique_groups[0];
301    let g2 = unique_groups[1];
302    let mut indices: Vec<usize> = (0..n).collect();
303    indices.sort_by(|&a, &b| {
304        time[a]
305            .partial_cmp(&time[b])
306            .unwrap_or(std::cmp::Ordering::Equal)
307    });
308    let mut n1_at_risk = 0.0;
309    let mut n2_at_risk = 0.0;
310    for &grp in group {
311        if grp == g1 {
312            n1_at_risk += 1.0;
313        } else if grp == g2 {
314            n2_at_risk += 1.0;
315        }
316    }
317    let mut sum_o_e: f64 = 0.0;
318    let mut sum_var: f64 = 0.0;
319    let mut i = 0;
320    while i < n {
321        let current_time = time[indices[i]];
322        let mut d1 = 0.0;
323        let mut d2 = 0.0;
324        let mut r1 = 0.0;
325        let mut r2 = 0.0;
326        while i < n && time[indices[i]] == current_time {
327            let idx = indices[i];
328            if group[idx] == g1 {
329                r1 += 1.0;
330                if status[idx] == 1 {
331                    d1 += 1.0;
332                }
333            } else if group[idx] == g2 {
334                r2 += 1.0;
335                if status[idx] == 1 {
336                    d2 += 1.0;
337                }
338            }
339            i += 1;
340        }
341        let d = d1 + d2;
342        let y = n1_at_risk + n2_at_risk;
343        if d > 0.0 && y > 1.0 {
344            let e1 = d * n1_at_risk / y;
345            sum_o_e += d1 - e1;
346            let v = d * n1_at_risk * n2_at_risk * (y - d) / (y * y * (y - 1.0));
347            sum_var += v;
348        }
349        n1_at_risk -= r1;
350        n2_at_risk -= r2;
351    }
352    let log_hr: f64 = if sum_var > 0.0 {
353        sum_o_e / sum_var
354    } else {
355        0.0
356    };
357    let hazard_ratio = log_hr.exp();
358    let se_log_hr: f64 = if sum_var > 0.0 {
359        1.0 / sum_var.sqrt()
360    } else {
361        0.0
362    };
363    let z: f64 = if confidence_level >= 0.99 {
364        2.576
365    } else if confidence_level >= 0.95 {
366        1.96
367    } else if confidence_level >= 0.90 {
368        1.645
369    } else {
370        1.28
371    };
372    let ci_lower = (log_hr - z * se_log_hr).exp();
373    let ci_upper = (log_hr + z * se_log_hr).exp();
374    let z_statistic: f64 = if se_log_hr > 0.0 {
375        log_hr / se_log_hr
376    } else {
377        0.0
378    };
379    let p_value = 2.0 * (1.0 - norm_cdf(z_statistic.abs()));
380    HazardRatioResult {
381        hazard_ratio,
382        ci_lower,
383        ci_upper,
384        se_log_hr,
385        z_statistic,
386        p_value,
387    }
388}
389/// Compute hazard ratio between two groups.
390///
391/// Parameters
392/// ----------
393/// time : array-like
394///     Survival/censoring times.
395/// status : array-like
396///     Event indicator (1=event, 0=censored).
397/// group : array-like
398///     Group indicator (0 or 1).
399/// confidence_level : float, optional
400///     Confidence level (default 0.95).
401///
402/// Returns
403/// -------
404/// HazardRatioResult
405///     Object with: hazard_ratio, std_err, conf_lower, conf_upper, p_value.
406#[pyfunction]
407#[pyo3(signature = (time, status, group, confidence_level=None))]
408pub fn hazard_ratio(
409    time: Vec<f64>,
410    status: Vec<i32>,
411    group: Vec<i32>,
412    confidence_level: Option<f64>,
413) -> PyResult<HazardRatioResult> {
414    let conf = confidence_level.unwrap_or(0.95);
415    Ok(compute_hazard_ratio(&time, &status, &group, conf))
416}
417#[derive(Debug, Clone)]
418#[pyclass]
419pub struct SurvivalAtTimeResult {
420    #[pyo3(get)]
421    pub time: f64,
422    #[pyo3(get)]
423    pub survival: f64,
424    #[pyo3(get)]
425    pub ci_lower: f64,
426    #[pyo3(get)]
427    pub ci_upper: f64,
428    #[pyo3(get)]
429    pub n_at_risk: usize,
430    #[pyo3(get)]
431    pub n_events: usize,
432}
433#[pymethods]
434impl SurvivalAtTimeResult {
435    #[new]
436    fn new(
437        time: f64,
438        survival: f64,
439        ci_lower: f64,
440        ci_upper: f64,
441        n_at_risk: usize,
442        n_events: usize,
443    ) -> Self {
444        Self {
445            time,
446            survival,
447            ci_lower,
448            ci_upper,
449            n_at_risk,
450            n_events,
451        }
452    }
453}
454pub fn compute_survival_at_times(
455    time: &[f64],
456    status: &[i32],
457    eval_times: &[f64],
458    confidence_level: f64,
459) -> Vec<SurvivalAtTimeResult> {
460    let n = time.len();
461    if n == 0 {
462        return eval_times
463            .iter()
464            .map(|&t| SurvivalAtTimeResult {
465                time: t,
466                survival: 1.0,
467                ci_lower: 1.0,
468                ci_upper: 1.0,
469                n_at_risk: 0,
470                n_events: 0,
471            })
472            .collect();
473    }
474    let mut indices: Vec<usize> = (0..n).collect();
475    indices.sort_by(|&a, &b| {
476        time[a]
477            .partial_cmp(&time[b])
478            .unwrap_or(std::cmp::Ordering::Equal)
479    });
480    let mut event_times: Vec<f64> = Vec::new();
481    let mut survival_vals: Vec<f64> = Vec::new();
482    let mut var_vals: Vec<f64> = Vec::new();
483    let mut n_risk_vals: Vec<usize> = Vec::new();
484    let mut cum_events: Vec<usize> = Vec::new();
485    let mut surv = 1.0;
486    let mut var_sum = 0.0;
487    let mut total_at_risk = n as f64;
488    let mut total_events = 0usize;
489    let mut i = 0;
490    while i < n {
491        let current_time = time[indices[i]];
492        let mut events = 0.0;
493        let mut removed = 0.0;
494        while i < n && time[indices[i]] == current_time {
495            removed += 1.0;
496            if status[indices[i]] == 1 {
497                events += 1.0;
498                total_events += 1;
499            }
500            i += 1;
501        }
502        if events > 0.0 && total_at_risk > 0.0 {
503            surv *= 1.0 - events / total_at_risk;
504            if total_at_risk > events {
505                var_sum += events / (total_at_risk * (total_at_risk - events));
506            }
507            event_times.push(current_time);
508            survival_vals.push(surv);
509            var_vals.push(surv * surv * var_sum);
510            n_risk_vals.push(total_at_risk as usize);
511            cum_events.push(total_events);
512        }
513        total_at_risk -= removed;
514    }
515    let z = z_score_for_confidence(confidence_level);
516    let results: Vec<SurvivalAtTimeResult> = if eval_times.len() > PARALLEL_THRESHOLD_SMALL {
517        eval_times
518            .par_iter()
519            .map(|&t| {
520                let (survival, var, n_risk, n_ev) = if event_times.is_empty() || t < event_times[0]
521                {
522                    (1.0, 0.0, n, 0)
523                } else {
524                    let idx = event_times.partition_point(|&et| et <= t);
525                    if idx == 0 {
526                        (1.0, 0.0, n, 0)
527                    } else {
528                        (
529                            survival_vals[idx - 1],
530                            var_vals[idx - 1],
531                            n_risk_vals[idx - 1],
532                            cum_events[idx - 1],
533                        )
534                    }
535                };
536                let se = var.sqrt();
537                let ci_lower = (survival - z * se).clamp(0.0, 1.0);
538                let ci_upper = (survival + z * se).clamp(0.0, 1.0);
539                SurvivalAtTimeResult {
540                    time: t,
541                    survival,
542                    ci_lower,
543                    ci_upper,
544                    n_at_risk: n_risk,
545                    n_events: n_ev,
546                }
547            })
548            .collect()
549    } else {
550        let mut results = Vec::with_capacity(eval_times.len());
551        for &t in eval_times {
552            let (survival, var, n_risk, n_ev) = if event_times.is_empty() || t < event_times[0] {
553                (1.0, 0.0, n, 0)
554            } else {
555                let idx = event_times.partition_point(|&et| et <= t);
556                if idx == 0 {
557                    (1.0, 0.0, n, 0)
558                } else {
559                    (
560                        survival_vals[idx - 1],
561                        var_vals[idx - 1],
562                        n_risk_vals[idx - 1],
563                        cum_events[idx - 1],
564                    )
565                }
566            };
567            let se = var.sqrt();
568            let ci_lower = (survival - z * se).clamp(0.0, 1.0);
569            let ci_upper = (survival + z * se).clamp(0.0, 1.0);
570            results.push(SurvivalAtTimeResult {
571                time: t,
572                survival,
573                ci_lower,
574                ci_upper,
575                n_at_risk: n_risk,
576                n_events: n_ev,
577            });
578        }
579        results
580    };
581    results
582}
583#[pyfunction]
584#[pyo3(signature = (time, status, eval_times, confidence_level=None))]
585pub fn survival_at_times(
586    time: Vec<f64>,
587    status: Vec<i32>,
588    eval_times: Vec<f64>,
589    confidence_level: Option<f64>,
590) -> PyResult<Vec<SurvivalAtTimeResult>> {
591    let conf = confidence_level.unwrap_or(0.95);
592    Ok(compute_survival_at_times(&time, &status, &eval_times, conf))
593}
594#[derive(Debug, Clone)]
595#[pyclass]
596pub struct LifeTableResult {
597    #[pyo3(get)]
598    pub interval_start: Vec<f64>,
599    #[pyo3(get)]
600    pub interval_end: Vec<f64>,
601    #[pyo3(get)]
602    pub n_at_risk: Vec<f64>,
603    #[pyo3(get)]
604    pub n_deaths: Vec<f64>,
605    #[pyo3(get)]
606    pub n_censored: Vec<f64>,
607    #[pyo3(get)]
608    pub n_effective: Vec<f64>,
609    #[pyo3(get)]
610    pub hazard: Vec<f64>,
611    #[pyo3(get)]
612    pub survival: Vec<f64>,
613    #[pyo3(get)]
614    pub se_survival: Vec<f64>,
615}
616#[pymethods]
617impl LifeTableResult {
618    #[new]
619    #[allow(clippy::too_many_arguments)]
620    fn new(
621        interval_start: Vec<f64>,
622        interval_end: Vec<f64>,
623        n_at_risk: Vec<f64>,
624        n_deaths: Vec<f64>,
625        n_censored: Vec<f64>,
626        n_effective: Vec<f64>,
627        hazard: Vec<f64>,
628        survival: Vec<f64>,
629        se_survival: Vec<f64>,
630    ) -> Self {
631        Self {
632            interval_start,
633            interval_end,
634            n_at_risk,
635            n_deaths,
636            n_censored,
637            n_effective,
638            hazard,
639            survival,
640            se_survival,
641        }
642    }
643}
644pub fn compute_life_table(time: &[f64], status: &[i32], breaks: &[f64]) -> LifeTableResult {
645    let n = time.len();
646    let n_intervals = breaks.len().saturating_sub(1);
647    if n == 0 || n_intervals == 0 {
648        return LifeTableResult {
649            interval_start: vec![],
650            interval_end: vec![],
651            n_at_risk: vec![],
652            n_deaths: vec![],
653            n_censored: vec![],
654            n_effective: vec![],
655            hazard: vec![],
656            survival: vec![],
657            se_survival: vec![],
658        };
659    }
660    let mut interval_start = Vec::with_capacity(n_intervals);
661    let mut interval_end = Vec::with_capacity(n_intervals);
662    let mut n_deaths = vec![0.0; n_intervals];
663    let mut n_censored = vec![0.0; n_intervals];
664    for i in 0..n_intervals {
665        interval_start.push(breaks[i]);
666        interval_end.push(breaks[i + 1]);
667    }
668    for i in 0..n {
669        let t = time[i];
670        for j in 0..n_intervals {
671            if t >= breaks[j] && t < breaks[j + 1] {
672                if status[i] == 1 {
673                    n_deaths[j] += 1.0;
674                } else {
675                    n_censored[j] += 1.0;
676                }
677                break;
678            }
679        }
680    }
681    let mut n_at_risk = Vec::with_capacity(n_intervals);
682    let mut remaining = n as f64;
683    for j in 0..n_intervals {
684        n_at_risk.push(remaining);
685        remaining -= n_deaths[j] + n_censored[j];
686    }
687    let n_effective: Vec<f64> = (0..n_intervals)
688        .map(|j| n_at_risk[j] - n_censored[j] / 2.0)
689        .collect();
690    let hazard: Vec<f64> = (0..n_intervals)
691        .map(|j| {
692            if n_effective[j] > 0.0 {
693                n_deaths[j] / n_effective[j]
694            } else {
695                0.0
696            }
697        })
698        .collect();
699    let mut survival = Vec::with_capacity(n_intervals);
700    let mut se_survival = Vec::with_capacity(n_intervals);
701    let mut surv = 1.0;
702    let mut var_sum = 0.0;
703    for j in 0..n_intervals {
704        surv *= 1.0 - hazard[j];
705        survival.push(surv);
706        if n_effective[j] > 0.0 && n_effective[j] > n_deaths[j] {
707            var_sum += n_deaths[j] / (n_effective[j] * (n_effective[j] - n_deaths[j]));
708        }
709        se_survival.push(surv * var_sum.sqrt());
710    }
711    LifeTableResult {
712        interval_start,
713        interval_end,
714        n_at_risk,
715        n_deaths,
716        n_censored,
717        n_effective,
718        hazard,
719        survival,
720        se_survival,
721    }
722}
723#[pyfunction]
724pub fn life_table(time: Vec<f64>, status: Vec<i32>, breaks: Vec<f64>) -> PyResult<LifeTableResult> {
725    Ok(compute_life_table(&time, &status, &breaks))
726}
727
#[cfg(test)]
mod tests {
    use super::*;

    /// Five subjects observed at t = 1..=5 with alternating event/censor status.
    fn five_subjects() -> (Vec<f64>, Vec<i32>) {
        (vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 0, 1, 0, 1])
    }

    #[test]
    fn test_compute_landmark_basic() {
        let (time, status) = five_subjects();
        let res = compute_landmark(&time, &status, 2.0);

        assert_eq!(res.landmark_time, 2.0);
        assert_eq!(res.n_at_risk, 3);
        assert_eq!(res.n_excluded, 2);
        assert_eq!(res.time.len(), 3);
    }

    #[test]
    fn test_compute_landmark_all_excluded() {
        let res = compute_landmark(&[1.0, 2.0, 3.0], &[1, 1, 1], 5.0);

        assert_eq!(res.n_at_risk, 0);
        assert_eq!(res.n_excluded, 3);
    }

    #[test]
    fn test_compute_landmark_none_excluded() {
        let res = compute_landmark(&[5.0, 6.0, 7.0], &[1, 0, 1], 1.0);

        assert_eq!(res.n_at_risk, 3);
        assert_eq!(res.n_excluded, 0);
        // Times are shifted so the landmark becomes the new origin.
        assert_eq!(res.time[0], 4.0);
    }

    #[test]
    fn test_compute_landmarks_parallel() {
        let (time, status) = five_subjects();
        let results = compute_landmarks_parallel(&time, &status, &[1.0, 2.0, 3.0]);

        assert_eq!(results.len(), 3);
        // A later landmark can only shrink the risk set.
        assert!(results
            .windows(2)
            .all(|pair| pair[0].n_at_risk >= pair[1].n_at_risk));
    }

    #[test]
    fn test_compute_life_table_basic() {
        let table = compute_life_table(
            &[1.5, 2.5, 3.5, 4.5, 5.5],
            &[1, 1, 0, 1, 0],
            &[0.0, 2.0, 4.0, 6.0],
        );

        assert_eq!(table.interval_start.len(), 3);
        assert_eq!(table.survival.len(), 3);
        assert!(table.survival.iter().all(|&s| (0.0..=1.0).contains(&s)));
    }

    #[test]
    fn test_compute_life_table_no_events() {
        let table = compute_life_table(&[1.5, 3.5, 5.5], &[0, 0, 0], &[0.0, 2.0, 4.0, 6.0]);

        assert_eq!(table.interval_start.len(), 3);
        assert!(table.n_deaths.iter().all(|&d| d == 0.0));
        assert!(table.survival.iter().all(|&s| s == 1.0));
    }

    #[test]
    fn test_landmark_result_new() {
        let res = LandmarkResult::new(2.0, 5, 3, vec![1.0, 2.0], vec![1, 0], vec![3, 4]);

        assert_eq!(res.landmark_time, 2.0);
        assert_eq!(res.n_at_risk, 5);
        assert_eq!(res.n_excluded, 3);
        assert_eq!(res.time.len(), 2);
    }
}