Skip to main content

survival/surv_analysis/
semi_markov.rs

1use pyo3::prelude::*;
2use std::collections::HashMap;
3
4use crate::utilities::statistical::erf;
5
6#[derive(Debug, Clone, Copy, PartialEq)]
7#[pyclass(eq, eq_int, from_py_object)]
8pub enum SojournDistribution {
9    Exponential,
10    Weibull,
11    LogNormal,
12    Gamma,
13    GeneralizedGamma,
14}
15
16#[pymethods]
17impl SojournDistribution {
18    fn __repr__(&self) -> String {
19        match self {
20            SojournDistribution::Exponential => "SojournDistribution.Exponential".to_string(),
21            SojournDistribution::Weibull => "SojournDistribution.Weibull".to_string(),
22            SojournDistribution::LogNormal => "SojournDistribution.LogNormal".to_string(),
23            SojournDistribution::Gamma => "SojournDistribution.Gamma".to_string(),
24            SojournDistribution::GeneralizedGamma => {
25                "SojournDistribution.GeneralizedGamma".to_string()
26            }
27        }
28    }
29}
30
31#[derive(Debug, Clone)]
32#[pyclass(from_py_object)]
33pub struct SemiMarkovConfig {
34    #[pyo3(get, set)]
35    pub n_states: usize,
36    #[pyo3(get, set)]
37    pub state_names: Vec<String>,
38    #[pyo3(get, set)]
39    pub sojourn_distributions: Vec<SojournDistribution>,
40    #[pyo3(get, set)]
41    pub absorbing_states: Vec<usize>,
42    #[pyo3(get, set)]
43    pub max_iter: usize,
44    #[pyo3(get, set)]
45    pub tol: f64,
46}
47
48#[pymethods]
49impl SemiMarkovConfig {
50    #[new]
51    #[pyo3(signature = (n_states, state_names=None, sojourn_distributions=None, absorbing_states=None, max_iter=100, tol=1e-6))]
52    pub fn new(
53        n_states: usize,
54        state_names: Option<Vec<String>>,
55        sojourn_distributions: Option<Vec<SojournDistribution>>,
56        absorbing_states: Option<Vec<usize>>,
57        max_iter: usize,
58        tol: f64,
59    ) -> Self {
60        let state_names =
61            state_names.unwrap_or_else(|| (0..n_states).map(|i| format!("State_{}", i)).collect());
62
63        let sojourn_distributions =
64            sojourn_distributions.unwrap_or_else(|| vec![SojournDistribution::Weibull; n_states]);
65
66        let absorbing_states = absorbing_states.unwrap_or_else(|| vec![n_states - 1]);
67
68        Self {
69            n_states,
70            state_names,
71            sojourn_distributions,
72            absorbing_states,
73            max_iter,
74            tol,
75        }
76    }
77}
78
79#[derive(Debug, Clone)]
80#[pyclass(from_py_object)]
81pub struct SojournTimeParams {
82    #[pyo3(get)]
83    pub distribution: SojournDistribution,
84    #[pyo3(get)]
85    pub shape: f64,
86    #[pyo3(get)]
87    pub scale: f64,
88    #[pyo3(get)]
89    pub location: f64,
90    #[pyo3(get)]
91    pub mean: f64,
92    #[pyo3(get)]
93    pub variance: f64,
94    #[pyo3(get)]
95    pub median: f64,
96}
97
98#[pymethods]
99impl SojournTimeParams {
100    fn __repr__(&self) -> String {
101        format!(
102            "SojournTimeParams(dist={:?}, mean={:.3}, var={:.3})",
103            self.distribution, self.mean, self.variance
104        )
105    }
106}
107
108#[derive(Debug, Clone)]
109#[pyclass(from_py_object)]
110pub struct SemiMarkovResult {
111    #[pyo3(get)]
112    pub transition_probs: HashMap<String, f64>,
113    #[pyo3(get)]
114    pub sojourn_params: Vec<SojournTimeParams>,
115    #[pyo3(get)]
116    pub state_occupation_probs: Vec<Vec<f64>>,
117    #[pyo3(get)]
118    pub time_points: Vec<f64>,
119    #[pyo3(get)]
120    pub mean_sojourn_times: Vec<f64>,
121    #[pyo3(get)]
122    pub n_transitions: HashMap<String, usize>,
123    #[pyo3(get)]
124    pub log_likelihood: f64,
125    #[pyo3(get)]
126    pub aic: f64,
127    #[pyo3(get)]
128    pub bic: f64,
129}
130
131#[pymethods]
132impl SemiMarkovResult {
133    fn __repr__(&self) -> String {
134        format!(
135            "SemiMarkovResult(n_states={}, ll={:.2}, aic={:.2})",
136            self.sojourn_params.len(),
137            self.log_likelihood,
138            self.aic
139        )
140    }
141
142    fn get_transition_prob(&self, from_state: usize, to_state: usize) -> f64 {
143        let key = format!("{}_{}", from_state, to_state);
144        *self.transition_probs.get(&key).unwrap_or(&0.0)
145    }
146
147    fn predict_state_at_time(&self, time: f64) -> Vec<f64> {
148        if self.time_points.is_empty() {
149            return vec![0.0; self.sojourn_params.len()];
150        }
151
152        let idx = self
153            .time_points
154            .iter()
155            .position(|&t| t >= time)
156            .unwrap_or(self.time_points.len() - 1);
157
158        self.state_occupation_probs[idx].clone()
159    }
160}
161
/// Weibull density `f(t) = (k/λ) · (t/λ)^(k-1) · exp(-(t/λ)^k)` with shape `k`
/// and scale `λ`.
///
/// Returns 0 when `t <= 0` or when either parameter is non-positive.
fn weibull_pdf(t: f64, shape: f64, scale: f64) -> f64 {
    let outside_support = t <= 0.0 || shape <= 0.0 || scale <= 0.0;
    if outside_support {
        return 0.0;
    }
    let z = t / scale;
    (shape / scale) * z.powf(shape - 1.0) * (-z.powf(shape)).exp()
}
168
/// Weibull cumulative distribution `F(t) = 1 - exp(-(t/λ)^k)`; 0 for `t <= 0`.
///
/// Unlike `weibull_pdf`, the shape and scale are not validated.
fn weibull_cdf(t: f64, shape: f64, scale: f64) -> f64 {
    if t <= 0.0 {
        return 0.0;
    }
    let survival = (-(t / scale).powf(shape)).exp();
    1.0 - survival
}
175
/// Weibull survival function `S(t) = exp(-(t/λ)^k)`; 1 for `t <= 0`.
fn weibull_survival(t: f64, shape: f64, scale: f64) -> f64 {
    if t <= 0.0 {
        return 1.0;
    }
    let cumulative_hazard = (t / scale).powf(shape);
    (-cumulative_hazard).exp()
}
182
/// Log-normal density with log-mean `mu` and log-standard-deviation `sigma`.
///
/// Returns 0 for `t <= 0` or `sigma <= 0`.
fn lognormal_pdf(t: f64, mu: f64, sigma: f64) -> f64 {
    if t <= 0.0 || sigma <= 0.0 {
        return 0.0;
    }
    let z = (t.ln() - mu) / sigma;
    let normaliser = 1.0 / (t * sigma * (2.0 * std::f64::consts::PI).sqrt());
    normaliser * (-0.5 * z.powi(2)).exp()
}
191
192fn lognormal_cdf(t: f64, mu: f64, sigma: f64) -> f64 {
193    if t <= 0.0 {
194        return 0.0;
195    }
196    0.5 * (1.0 + erf((t.ln() - mu) / (sigma * std::f64::consts::SQRT_2)))
197}
198
199fn gamma_pdf(t: f64, shape: f64, rate: f64) -> f64 {
200    if t <= 0.0 || shape <= 0.0 || rate <= 0.0 {
201        return 0.0;
202    }
203    let ln_gamma = ln_gamma_fn(shape);
204    (shape * rate.ln() + (shape - 1.0) * t.ln() - rate * t - ln_gamma).exp()
205}
206
/// Natural logarithm of the gamma function, `ln Γ(x)`, intended for `x > 0`.
///
/// Six-term Lanczos approximation. The coefficient set appears to be the one
/// from Numerical Recipes' `gammln`. The final constant is `sqrt(2π)`.
fn ln_gamma_fn(x: f64) -> f64 {
    const LANCZOS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    const SERIES_BASE: f64 = 1.000000000190015;
    const SQRT_TWO_PI: f64 = 2.5066282746310005;

    let shifted = x + 5.5;
    let correction = shifted - (x + 0.5) * shifted.ln();
    let series = LANCZOS
        .iter()
        .enumerate()
        .fold(SERIES_BASE, |acc, (k, &c)| acc + c / (x + 1.0 + k as f64));

    -correction + (SQRT_TWO_PI * series / x).ln()
}
225
/// Maximum-likelihood fit of a two-parameter Weibull distribution.
///
/// Returns `(shape, scale)`. The shape `k` solves the profile score equation
///
/// `g(k) = Σ t^k ln t / Σ t^k - 1/k - mean(ln t) = 0`
///
/// by Newton's method, using the exact derivative
/// `g'(k) = Var_w(ln t) + 1/k²` with weights `w ∝ t^k`. The scale then follows
/// in closed form as `λ = (Σ t^k / n)^(1/k)`.
///
/// * Empty input returns `(1.0, 1.0)`.
/// * Non-positive durations are floored at `1e-10`, as in `fit_lognormal_mle`.
/// * The shape is kept within `[0.1, 100]`. Degenerate samples (all values
///   equal) have an unbounded MLE and end at the upper bound.
fn fit_weibull_mle(times: &[f64]) -> (f64, f64) {
    const MIN_SHAPE: f64 = 0.1;
    const MAX_SHAPE: f64 = 100.0;
    const MAX_ITER: usize = 50;
    const TOL: f64 = 1e-6;

    if times.is_empty() {
        return (1.0, 1.0);
    }

    let floored: Vec<f64> = times.iter().map(|&t| t.max(1e-10)).collect();
    let t_max = floored.iter().copied().fold(f64::MIN_POSITIVE, f64::max);
    // The shape equation is invariant to rescaling t, so work with u = t / t_max.
    // Then u^k <= 1 can never overflow, and the largest term is exactly 1.
    let log_u: Vec<f64> = floored.iter().map(|&t| (t / t_max).ln()).collect();
    let n = log_u.len() as f64;
    let mean_log_u = log_u.iter().sum::<f64>() / n;

    let mut shape = 1.0_f64;
    for _ in 0..MAX_ITER {
        // Weighted moments of ln u with weights u^k.
        let (mut w_sum, mut w_log, mut w_log2) = (0.0_f64, 0.0_f64, 0.0_f64);
        for &lu in &log_u {
            let w = (shape * lu).exp();
            w_sum += w;
            w_log += w * lu;
            w_log2 += w * lu * lu;
        }
        let weighted_mean = w_log / w_sum;
        let score = weighted_mean - 1.0 / shape - mean_log_u;
        // g'(k) > 0: the weighted variance is non-negative and 1/k² is positive.
        let slope = (w_log2 / w_sum - weighted_mean * weighted_mean) + 1.0 / (shape * shape);

        let mut next = shape - score / slope;
        if !next.is_finite() || next <= 0.0 {
            // Overshot past zero: damp the step instead.
            next = shape / 2.0;
        }
        let next = next.clamp(MIN_SHAPE, MAX_SHAPE);
        let converged = (next - shape).abs() < TOL;
        shape = next;
        if converged {
            break;
        }
    }

    let mean_u_pow = log_u.iter().map(|&lu| (shape * lu).exp()).sum::<f64>() / n;
    let scale = t_max * mean_u_pow.powf(1.0 / shape);

    (shape, scale.max(1e-10))
}
258
/// Maximum-likelihood fit of a log-normal distribution.
///
/// Returns `(mu, sigma)`: the mean and the population (divide-by-n) standard
/// deviation of `ln t`.
/// * Durations are floored at `1e-10` before taking logs.
/// * `sigma` is floored at `0.01`.
/// * Empty input returns `(0.0, 1.0)`.
fn fit_lognormal_mle(times: &[f64]) -> (f64, f64) {
    if times.is_empty() {
        return (0.0, 1.0);
    }

    let logs: Vec<f64> = times.iter().map(|&t| t.max(1e-10).ln()).collect();
    let count = logs.len() as f64;

    let mu = logs.iter().sum::<f64>() / count;
    let squared_deviation = logs.iter().map(|&l| (l - mu).powi(2)).sum::<f64>();
    let sigma = (squared_deviation / count).sqrt().max(0.01);

    (mu, sigma)
}
273
/// Approximate maximum-likelihood fit of a Gamma(shape, rate) distribution.
///
/// The shape uses the closed-form approximation
/// `k ≈ (3 - s + sqrt((s - 3)² + 24 s)) / (12 s)`, where
/// `s = ln(mean t) - mean(ln t)`. If `s <= 0`, which only happens for
/// degenerate or invalid samples, the shape defaults to 1. The shape is
/// floored at 0.1.
///
/// The rate is then set to `shape / mean`, the exact MLE of the rate for a
/// given shape. This means the fitted mean `shape / rate` always equals the
/// sample mean.
///
/// * Returns `(shape, rate)`.
/// * Empty input returns `(1.0, 1.0)`.
/// * Durations are floored at `1e-10`.
fn fit_gamma_mle(times: &[f64]) -> (f64, f64) {
    if times.is_empty() {
        return (1.0, 1.0);
    }

    // Floor non-positive durations for both the mean and the logs, so that
    // `mean.ln()` is always defined.
    let floored: Vec<f64> = times.iter().map(|&t| t.max(1e-10)).collect();
    let n = floored.len() as f64;
    let mean = floored.iter().sum::<f64>() / n;
    let log_mean = floored.iter().map(|&t| t.ln()).sum::<f64>() / n;

    let s = mean.ln() - log_mean;
    let shape = if s > 0.0 {
        (3.0 - s + ((s - 3.0).powi(2) + 24.0 * s).sqrt()) / (12.0 * s)
    } else {
        1.0
    };
    let shape = shape.max(0.1);

    // Derive the rate from the clamped shape, so the two stay consistent. The
    // rate is deliberately not floored: a floor would cap the fitted mean and
    // make the fit depend on the time unit.
    let rate = shape / mean;

    (shape, rate)
}
294
295#[pyfunction]
296#[pyo3(signature = (entry_times, exit_times, from_states, to_states, config))]
297pub fn fit_semi_markov(
298    entry_times: Vec<f64>,
299    exit_times: Vec<f64>,
300    from_states: Vec<i32>,
301    to_states: Vec<i32>,
302    config: &SemiMarkovConfig,
303) -> PyResult<SemiMarkovResult> {
304    let n = entry_times.len();
305    if exit_times.len() != n || from_states.len() != n || to_states.len() != n {
306        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
307            "All input vectors must have the same length",
308        ));
309    }
310
311    let sojourn_times: Vec<f64> = entry_times
312        .iter()
313        .zip(exit_times.iter())
314        .map(|(&entry, &exit)| (exit - entry).max(1e-10))
315        .collect();
316
317    let mut transition_counts: HashMap<String, usize> = HashMap::new();
318    let mut state_counts: Vec<usize> = vec![0; config.n_states];
319
320    for i in 0..n {
321        let from = from_states[i] as usize;
322        let to = to_states[i] as usize;
323        if from < config.n_states && to < config.n_states {
324            let key = format!("{}_{}", from, to);
325            *transition_counts.entry(key).or_insert(0) += 1;
326            state_counts[from] += 1;
327        }
328    }
329
330    let mut transition_probs: HashMap<String, f64> = HashMap::new();
331    for (from, &state_count) in state_counts.iter().enumerate().take(config.n_states) {
332        for to in 0..config.n_states {
333            let key = format!("{}_{}", from, to);
334            let count = *transition_counts.get(&key).unwrap_or(&0);
335            let prob = if state_count > 0 {
336                count as f64 / state_count as f64
337            } else {
338                0.0
339            };
340            transition_probs.insert(key, prob);
341        }
342    }
343
344    let mut sojourn_params: Vec<SojournTimeParams> = Vec::new();
345    let mut log_likelihood = 0.0;
346    let mut n_params = 0;
347
348    for state in 0..config.n_states {
349        let state_sojourn: Vec<f64> = (0..n)
350            .filter(|&i| from_states[i] as usize == state)
351            .map(|i| sojourn_times[i])
352            .collect();
353
354        let dist = config.sojourn_distributions[state];
355        let (shape, scale, location, mean, variance, median) = if state_sojourn.is_empty() {
356            (1.0, 1.0, 0.0, 1.0, 1.0, 1.0)
357        } else {
358            match dist {
359                SojournDistribution::Exponential => {
360                    let mean = state_sojourn.iter().sum::<f64>() / state_sojourn.len() as f64;
361                    let rate = 1.0 / mean;
362                    for &t in &state_sojourn {
363                        log_likelihood += rate.ln() - rate * t;
364                    }
365                    n_params += 1;
366                    (1.0, mean, 0.0, mean, mean.powi(2), mean * 2.0_f64.ln())
367                }
368                SojournDistribution::Weibull => {
369                    let (shape, scale) = fit_weibull_mle(&state_sojourn);
370                    for &t in &state_sojourn {
371                        let pdf = weibull_pdf(t, shape, scale);
372                        if pdf > 1e-300 {
373                            log_likelihood += pdf.ln();
374                        }
375                    }
376                    n_params += 2;
377                    let mean = scale * ln_gamma_fn(1.0 + 1.0 / shape).exp();
378                    let var = scale.powi(2)
379                        * (ln_gamma_fn(1.0 + 2.0 / shape).exp()
380                            - ln_gamma_fn(1.0 + 1.0 / shape).exp().powi(2));
381                    let median = scale * 2.0_f64.ln().powf(1.0 / shape);
382                    (shape, scale, 0.0, mean, var, median)
383                }
384                SojournDistribution::LogNormal => {
385                    let (mu, sigma) = fit_lognormal_mle(&state_sojourn);
386                    for &t in &state_sojourn {
387                        let pdf = lognormal_pdf(t, mu, sigma);
388                        if pdf > 1e-300 {
389                            log_likelihood += pdf.ln();
390                        }
391                    }
392                    n_params += 2;
393                    let mean = (mu + sigma.powi(2) / 2.0).exp();
394                    let var = (sigma.powi(2).exp() - 1.0) * (2.0 * mu + sigma.powi(2)).exp();
395                    let median = mu.exp();
396                    (sigma, mu.exp(), mu, mean, var, median)
397                }
398                SojournDistribution::Gamma => {
399                    let (shape, rate) = fit_gamma_mle(&state_sojourn);
400                    for &t in &state_sojourn {
401                        let pdf = gamma_pdf(t, shape, rate);
402                        if pdf > 1e-300 {
403                            log_likelihood += pdf.ln();
404                        }
405                    }
406                    n_params += 2;
407                    let mean = shape / rate;
408                    let var = shape / rate.powi(2);
409                    let median = mean * (1.0 - 1.0 / (9.0 * shape)).powi(3);
410                    (shape, 1.0 / rate, 0.0, mean, var, median)
411                }
412                SojournDistribution::GeneralizedGamma => {
413                    let (shape, scale) = fit_weibull_mle(&state_sojourn);
414                    for &t in &state_sojourn {
415                        let pdf = weibull_pdf(t, shape, scale);
416                        if pdf > 1e-300 {
417                            log_likelihood += pdf.ln();
418                        }
419                    }
420                    n_params += 3;
421                    let mean = scale * ln_gamma_fn(1.0 + 1.0 / shape).exp();
422                    let var = scale.powi(2)
423                        * (ln_gamma_fn(1.0 + 2.0 / shape).exp()
424                            - ln_gamma_fn(1.0 + 1.0 / shape).exp().powi(2));
425                    let median = scale * 2.0_f64.ln().powf(1.0 / shape);
426                    (shape, scale, 0.0, mean, var, median)
427                }
428            }
429        };
430
431        sojourn_params.push(SojournTimeParams {
432            distribution: dist,
433            shape,
434            scale,
435            location,
436            mean,
437            variance,
438            median,
439        });
440    }
441
442    let max_time = exit_times.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
443    let n_time_points = 100;
444    let time_points: Vec<f64> = (0..=n_time_points)
445        .map(|i| i as f64 * max_time / n_time_points as f64)
446        .collect();
447
448    let mut state_occupation_probs: Vec<Vec<f64>> = Vec::new();
449    for &t in &time_points {
450        let mut probs = vec![0.0; config.n_states];
451        if t == 0.0 {
452            probs[0] = 1.0;
453        } else {
454            for state in 0..config.n_states {
455                if config.absorbing_states.contains(&state) {
456                    let mut absorb_prob = 0.0;
457                    for (from, params) in sojourn_params.iter().enumerate().take(config.n_states) {
458                        if !config.absorbing_states.contains(&from) {
459                            let key = format!("{}_{}", from, state);
460                            let trans_prob = *transition_probs.get(&key).unwrap_or(&0.0);
461                            let cdf = match params.distribution {
462                                SojournDistribution::Weibull
463                                | SojournDistribution::GeneralizedGamma => {
464                                    weibull_cdf(t, params.shape, params.scale)
465                                }
466                                SojournDistribution::Exponential => 1.0 - (-t / params.scale).exp(),
467                                SojournDistribution::LogNormal => {
468                                    lognormal_cdf(t, params.location, params.shape)
469                                }
470                                SojournDistribution::Gamma => {
471                                    1.0 - (-t * params.shape / params.scale).exp()
472                                }
473                            };
474                            absorb_prob += trans_prob * cdf;
475                        }
476                    }
477                    probs[state] = absorb_prob.min(1.0);
478                } else {
479                    let params = &sojourn_params[state];
480                    let surv = match params.distribution {
481                        SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
482                            weibull_survival(t, params.shape, params.scale)
483                        }
484                        SojournDistribution::Exponential => (-t / params.scale).exp(),
485                        SojournDistribution::LogNormal => {
486                            1.0 - lognormal_cdf(t, params.location, params.shape)
487                        }
488                        SojournDistribution::Gamma => (-t * params.shape / params.scale).exp(),
489                    };
490                    probs[state] = surv * (1.0 - probs.iter().sum::<f64>()).max(0.0);
491                }
492            }
493        }
494
495        let sum: f64 = probs.iter().sum();
496        if sum > 0.0 {
497            for p in &mut probs {
498                *p /= sum;
499            }
500        }
501        state_occupation_probs.push(probs);
502    }
503
504    let mean_sojourn_times: Vec<f64> = sojourn_params.iter().map(|p| p.mean).collect();
505
506    let n_obs = n as f64;
507    let aic = -2.0 * log_likelihood + 2.0 * n_params as f64;
508    let bic = -2.0 * log_likelihood + (n_params as f64) * n_obs.ln();
509
510    Ok(SemiMarkovResult {
511        transition_probs,
512        sojourn_params,
513        state_occupation_probs,
514        time_points,
515        mean_sojourn_times,
516        n_transitions: transition_counts,
517        log_likelihood,
518        aic,
519        bic,
520    })
521}
522
523#[derive(Debug, Clone)]
524#[pyclass(from_py_object)]
525pub struct SemiMarkovPrediction {
526    #[pyo3(get)]
527    pub state_probs: Vec<Vec<f64>>,
528    #[pyo3(get)]
529    pub time_points: Vec<f64>,
530    #[pyo3(get)]
531    pub expected_sojourn: Vec<f64>,
532    #[pyo3(get)]
533    pub transition_hazards: HashMap<String, Vec<f64>>,
534}
535
536#[pyfunction]
537#[pyo3(signature = (model, current_state, time_in_state, prediction_times))]
538pub fn predict_semi_markov(
539    model: &SemiMarkovResult,
540    current_state: usize,
541    time_in_state: f64,
542    prediction_times: Vec<f64>,
543) -> PyResult<SemiMarkovPrediction> {
544    let n_states = model.sojourn_params.len();
545    if current_state >= n_states {
546        return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
547            "current_state must be less than number of states",
548        ));
549    }
550
551    let params = &model.sojourn_params[current_state];
552    let current_survival = match params.distribution {
553        SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
554            weibull_survival(time_in_state, params.shape, params.scale)
555        }
556        SojournDistribution::Exponential => (-time_in_state / params.scale).exp(),
557        _ => 1.0,
558    };
559
560    let mut state_probs: Vec<Vec<f64>> = Vec::new();
561    let mut transition_hazards: HashMap<String, Vec<f64>> = HashMap::new();
562
563    for to_state in 0..n_states {
564        let key = format!("{}_{}", current_state, to_state);
565        transition_hazards.insert(key.clone(), Vec::new());
566    }
567
568    for &t in &prediction_times {
569        let total_time = time_in_state + t;
570        let mut probs = vec![0.0; n_states];
571
572        let future_survival = match params.distribution {
573            SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
574                weibull_survival(total_time, params.shape, params.scale)
575            }
576            SojournDistribution::Exponential => (-total_time / params.scale).exp(),
577            _ => 1.0,
578        };
579
580        let conditional_survival = if current_survival > 1e-10 {
581            future_survival / current_survival
582        } else {
583            0.0
584        };
585
586        probs[current_state] = conditional_survival;
587
588        let exit_prob = 1.0 - conditional_survival;
589        for (to_state, prob) in probs.iter_mut().enumerate().take(n_states) {
590            if to_state != current_state {
591                let key = format!("{}_{}", current_state, to_state);
592                let trans_prob = *model.transition_probs.get(&key).unwrap_or(&0.0);
593                *prob = exit_prob * trans_prob;
594            }
595        }
596
597        state_probs.push(probs);
598
599        for to_state in 0..n_states {
600            let key = format!("{}_{}", current_state, to_state);
601            let trans_prob = *model.transition_probs.get(&key).unwrap_or(&0.0);
602
603            let hazard = if conditional_survival > 1e-10 {
604                let pdf = match params.distribution {
605                    SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
606                        weibull_pdf(total_time, params.shape, params.scale)
607                    }
608                    SojournDistribution::Exponential => {
609                        (1.0 / params.scale) * (-total_time / params.scale).exp()
610                    }
611                    _ => 0.0,
612                };
613                trans_prob * pdf / future_survival
614            } else {
615                0.0
616            };
617
618            if let Some(hazards) = transition_hazards.get_mut(&key) {
619                hazards.push(hazard);
620            } else {
621                return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
622                    "internal error: missing transition hazard bucket",
623                ));
624            }
625        }
626    }
627
628    let expected_sojourn: Vec<f64> = model.sojourn_params.iter().map(|p| p.mean).collect();
629
630    Ok(SemiMarkovPrediction {
631        state_probs,
632        time_points: prediction_times,
633        expected_sojourn,
634        transition_hazards,
635    })
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-state configuration with default names, distributions and limits.
    fn three_state_config(absorbing: Option<Vec<usize>>) -> SemiMarkovConfig {
        SemiMarkovConfig::new(3, None, None, absorbing, 100, 1e-6)
    }

    #[test]
    fn test_semi_markov_config() {
        let cfg = three_state_config(None);
        assert_eq!(cfg.n_states, 3);
        assert_eq!(cfg.state_names.len(), 3);
        assert_eq!(cfg.sojourn_distributions.len(), 3);
    }

    #[test]
    fn test_weibull_functions() {
        let (t, shape, scale) = (1.0, 2.0, 1.0);

        let density = weibull_pdf(t, shape, scale);
        assert!(density > 0.0 && density < 1.0);

        let cumulative = weibull_cdf(t, shape, scale);
        assert!(cumulative > 0.0 && cumulative < 1.0);

        // Survival and CDF must be complementary.
        let survival = weibull_survival(t, shape, scale);
        assert!((survival + cumulative - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_fit_semi_markov() {
        let result = fit_semi_markov(
            vec![0.0, 1.0, 2.0, 3.0, 0.0, 1.5, 2.5, 3.5],
            vec![1.0, 2.0, 3.0, 4.0, 1.5, 2.5, 3.5, 5.0],
            vec![0, 0, 1, 1, 0, 0, 1, 1],
            vec![1, 1, 2, 2, 1, 1, 2, 2],
            &three_state_config(Some(vec![2])),
        )
        .unwrap();

        assert_eq!(result.sojourn_params.len(), 3);
        assert!(!result.transition_probs.is_empty());
        assert!(result.log_likelihood.is_finite());
    }

    #[test]
    fn test_predict_semi_markov() {
        let model = fit_semi_markov(
            vec![0.0, 1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0, 4.0],
            vec![0, 0, 1, 1],
            vec![1, 1, 2, 2],
            &three_state_config(Some(vec![2])),
        )
        .unwrap();

        let horizon = vec![0.5, 1.0, 1.5, 2.0];
        let prediction = predict_semi_markov(&model, 0, 0.5, horizon.clone()).unwrap();

        assert_eq!(prediction.state_probs.len(), horizon.len());
        assert_eq!(prediction.time_points.len(), horizon.len());
    }
}