Skip to main content

survival/specialized/pyears_summary.rs

1use crate::constants::z_score_for_confidence;
2use pyo3::prelude::*;
3use std::fmt;
4
/// Aggregate person-years statistics over all cells of a person-years table.
///
/// Instances are built by `summary_pyears`; the field docs below describe the
/// values that function stores. All fields are read-only from Python.
#[derive(Debug, Clone)]
#[pyclass(str)]
pub struct PyearsSummary {
    /// Sum of the per-cell person-years (`pyears`).
    #[pyo3(get)]
    pub total_person_years: f64,
    /// Sum of the per-cell observed event counts (`pcount`).
    #[pyo3(get)]
    pub total_events: f64,
    /// Sum of the per-cell expected event counts (`pexpect`).
    #[pyo3(get)]
    pub total_expected: f64,
    /// Sum of the per-cell observation counts (`pn`).
    #[pyo3(get)]
    pub n_observations: f64,
    /// Copied unchanged from the `offtable` argument of `summary_pyears`
    /// (presumably person-time falling outside the tabulated cells — confirm
    /// against the caller).
    #[pyo3(get)]
    pub offtable: f64,
    /// `total_events / total_person_years`, or `0.0` when total person-years
    /// is not positive.
    #[pyo3(get)]
    pub observed_rate: f64,
    /// `total_expected / total_person_years`, or `0.0` when total person-years
    /// is not positive.
    #[pyo3(get)]
    pub expected_rate: f64,
    /// Observed-to-expected ratio `total_events / total_expected`; `NaN` when
    /// the total expected count is not positive.
    #[pyo3(get)]
    pub smr: f64,
    /// Same observed/expected ratio as `smr`, exposed under the SIR name.
    #[pyo3(get)]
    pub sir: f64,
}
27
28impl fmt::Display for PyearsSummary {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        write!(
31            f,
32            "PyearsSummary(person_years={:.2}, events={:.0}, expected={:.2}, SMR={:.3})",
33            self.total_person_years, self.total_events, self.total_expected, self.smr
34        )
35    }
36}
37
38#[pymethods]
39impl PyearsSummary {
40    pub fn to_table(&self) -> String {
41        let mut table = String::new();
42        table.push_str("Person-Years Summary\n");
43        table.push_str("====================\n\n");
44        table.push_str(&format!(
45            "Total person-years: {:>12.2}\n",
46            self.total_person_years
47        ));
48        table.push_str(&format!(
49            "Total observations: {:>12.0}\n",
50            self.n_observations
51        ));
52        table.push_str(&format!("Off-table:          {:>12.2}\n", self.offtable));
53        table.push('\n');
54        table.push_str(&format!(
55            "Observed events:    {:>12.0}\n",
56            self.total_events
57        ));
58        table.push_str(&format!(
59            "Expected events:    {:>12.2}\n",
60            self.total_expected
61        ));
62        table.push('\n');
63        table.push_str(&format!(
64            "Observed rate:      {:>12.6}\n",
65            self.observed_rate
66        ));
67        table.push_str(&format!(
68            "Expected rate:      {:>12.6}\n",
69            self.expected_rate
70        ));
71        table.push('\n');
72        table.push_str(&format!("SMR (O/E):          {:>12.3}\n", self.smr));
73        table.push_str(&format!("SIR (O/E):          {:>12.3}\n", self.sir));
74        table
75    }
76}
77
78#[pyfunction]
79pub fn summary_pyears(
80    pyears: Vec<f64>,
81    pn: Vec<f64>,
82    pcount: Vec<f64>,
83    pexpect: Vec<f64>,
84    offtable: f64,
85) -> PyResult<PyearsSummary> {
86    let total_person_years: f64 = pyears.iter().sum();
87    let total_events: f64 = pcount.iter().sum();
88    let total_expected: f64 = pexpect.iter().sum();
89    let n_observations: f64 = pn.iter().sum();
90
91    let observed_rate = if total_person_years > 0.0 {
92        total_events / total_person_years
93    } else {
94        0.0
95    };
96
97    let expected_rate = if total_person_years > 0.0 {
98        total_expected / total_person_years
99    } else {
100        0.0
101    };
102
103    let smr = if total_expected > 0.0 {
104        total_events / total_expected
105    } else {
106        f64::NAN
107    };
108
109    let sir = smr;
110
111    Ok(PyearsSummary {
112        total_person_years,
113        total_events,
114        total_expected,
115        n_observations,
116        offtable,
117        observed_rate,
118        expected_rate,
119        smr,
120        sir,
121    })
122}
123
/// Statistics for a single cell of a person-years table.
///
/// Instances are built by `pyears_by_cell`, one per input position; the field
/// docs below describe the values that function stores. All fields are
/// read-only from Python.
#[derive(Debug, Clone)]
#[pyclass(str)]
pub struct PyearsCell {
    /// Person-years for this cell (from `pyears`).
    #[pyo3(get)]
    pub person_years: f64,
    /// Observation count for this cell (from `pn`).
    #[pyo3(get)]
    pub n: f64,
    /// Observed event count for this cell (from `pcount`).
    #[pyo3(get)]
    pub events: f64,
    /// Expected event count for this cell (from `pexpect`).
    #[pyo3(get)]
    pub expected: f64,
    /// `events / person_years`, or `0.0` when person-years is not positive.
    #[pyo3(get)]
    pub rate: f64,
    /// `events / expected`, or `NaN` when the expected count is not positive.
    #[pyo3(get)]
    pub smr: f64,
}
140
141impl fmt::Display for PyearsCell {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        write!(
144            f,
145            "PyearsCell(py={:.2}, events={:.0}, expected={:.2})",
146            self.person_years, self.events, self.expected
147        )
148    }
149}
150
151#[pyfunction]
152pub fn pyears_by_cell(
153    pyears: Vec<f64>,
154    pn: Vec<f64>,
155    pcount: Vec<f64>,
156    pexpect: Vec<f64>,
157) -> PyResult<Vec<PyearsCell>> {
158    let n = pyears.len();
159    let mut cells = Vec::with_capacity(n);
160
161    for i in 0..n {
162        let py = pyears[i];
163        let events = pcount[i];
164        let expected = pexpect[i];
165
166        let rate = if py > 0.0 { events / py } else { 0.0 };
167        let smr = if expected > 0.0 {
168            events / expected
169        } else {
170            f64::NAN
171        };
172
173        cells.push(PyearsCell {
174            person_years: py,
175            n: pn[i],
176            events,
177            expected,
178            rate,
179            smr,
180        });
181    }
182
183    Ok(cells)
184}
185
186#[pyfunction]
187pub fn pyears_ci(observed: f64, expected: f64, conf_level: f64) -> PyResult<(f64, f64, f64)> {
188    let smr = if expected > 0.0 {
189        observed / expected
190    } else {
191        f64::NAN
192    };
193
194    let z = z_score_for_confidence(conf_level);
195
196    let se_log = if observed > 0.0 {
197        1.0 / observed.sqrt()
198    } else {
199        f64::INFINITY
200    };
201
202    let lower = if observed > 0.0 {
203        smr * (-z * se_log).exp()
204    } else {
205        0.0
206    };
207
208    let upper = smr * (z * se_log).exp();
209
210    Ok((smr, lower, upper))
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_summary_pyears() {
        let pyears = vec![100.0, 200.0, 150.0];
        let pn = vec![50.0, 80.0, 60.0];
        let pcount = vec![5.0, 10.0, 7.0];
        let pexpect = vec![4.0, 8.0, 6.0];
        let offtable = 5.0;

        let summary = summary_pyears(pyears, pn, pcount, pexpect, offtable).unwrap();

        assert!((summary.total_person_years - 450.0).abs() < 1e-10);
        assert!((summary.total_events - 22.0).abs() < 1e-10);
        assert!((summary.total_expected - 18.0).abs() < 1e-10);
        assert!((summary.smr - 22.0 / 18.0).abs() < 1e-10);
    }

    /// With no person-years and no expected events, the rates fall back to
    /// zero and the ratios are undefined (NaN).
    #[test]
    fn test_summary_pyears_empty_input() {
        let summary = summary_pyears(vec![], vec![], vec![], vec![], 0.0).unwrap();

        assert_eq!(summary.total_person_years, 0.0);
        assert_eq!(summary.n_observations, 0.0);
        assert_eq!(summary.observed_rate, 0.0);
        assert_eq!(summary.expected_rate, 0.0);
        assert!(summary.smr.is_nan());
        assert!(summary.sir.is_nan());
    }

    #[test]
    fn test_pyears_by_cell() {
        let cells = pyears_by_cell(
            vec![100.0, 0.0],
            vec![10.0, 2.0],
            vec![5.0, 1.0],
            vec![4.0, 0.0],
        )
        .unwrap();

        assert_eq!(cells.len(), 2);
        assert!((cells[0].rate - 0.05).abs() < 1e-12);
        assert!((cells[0].smr - 1.25).abs() < 1e-12);
        assert_eq!(cells[0].n, 10.0);
        // Zero person-years / zero expected use the documented fallbacks.
        assert_eq!(cells[1].rate, 0.0);
        assert!(cells[1].smr.is_nan());
    }

    #[test]
    fn test_pyears_ci() {
        let (smr, lower, upper) = pyears_ci(20.0, 10.0, 0.95).unwrap();

        assert!((smr - 2.0).abs() < 1e-10);
        assert!(lower < smr);
        assert!(upper > smr);
    }
}