1use pyo3::prelude::*;
2use std::collections::HashMap;
3
4use crate::utilities::statistical::erf;
5
/// Parametric family used to model how long a process stays in a state.
///
/// Exposed to Python as a class through `#[pyclass]` with the `eq` and
/// `eq_int` options enabled.
#[derive(Debug, Clone, Copy, PartialEq)]
#[pyclass(eq, eq_int, from_py_object)]
pub enum SojournDistribution {
    /// Exponential sojourn times (single rate parameter).
    Exponential,
    /// Weibull sojourn times (shape and scale).
    Weibull,
    /// Log-normal sojourn times (log-mean and log-standard-deviation).
    LogNormal,
    /// Gamma sojourn times (shape and rate).
    Gamma,
    /// Generalized gamma; `fit_semi_markov` in this file fits it with the
    /// Weibull estimator.
    GeneralizedGamma,
}
15
16#[pymethods]
17impl SojournDistribution {
18 fn __repr__(&self) -> String {
19 match self {
20 SojournDistribution::Exponential => "SojournDistribution.Exponential".to_string(),
21 SojournDistribution::Weibull => "SojournDistribution.Weibull".to_string(),
22 SojournDistribution::LogNormal => "SojournDistribution.LogNormal".to_string(),
23 SojournDistribution::Gamma => "SojournDistribution.Gamma".to_string(),
24 SojournDistribution::GeneralizedGamma => {
25 "SojournDistribution.GeneralizedGamma".to_string()
26 }
27 }
28 }
29}
30
/// Configuration of a semi-Markov model.
///
/// All fields are readable and writable from Python (`#[pyo3(get, set)]`),
/// so their lengths are not guaranteed to stay consistent with `n_states`
/// after construction.
#[derive(Debug, Clone)]
#[pyclass(from_py_object)]
pub struct SemiMarkovConfig {
    /// Number of states in the model.
    #[pyo3(get, set)]
    pub n_states: usize,
    /// Human-readable label per state.
    #[pyo3(get, set)]
    pub state_names: Vec<String>,
    /// Sojourn-time family per state, indexed by state number.
    #[pyo3(get, set)]
    pub sojourn_distributions: Vec<SojournDistribution>,
    /// Indices of the absorbing states.
    #[pyo3(get, set)]
    pub absorbing_states: Vec<usize>,
    /// Iteration limit. NOTE(review): not read by `fit_semi_markov` in this
    /// file; confirm whether other code consumes it.
    #[pyo3(get, set)]
    pub max_iter: usize,
    /// Convergence tolerance. NOTE(review): not read by `fit_semi_markov` in
    /// this file; confirm whether other code consumes it.
    #[pyo3(get, set)]
    pub tol: f64,
}
47
48#[pymethods]
49impl SemiMarkovConfig {
50 #[new]
51 #[pyo3(signature = (n_states, state_names=None, sojourn_distributions=None, absorbing_states=None, max_iter=100, tol=1e-6))]
52 pub fn new(
53 n_states: usize,
54 state_names: Option<Vec<String>>,
55 sojourn_distributions: Option<Vec<SojournDistribution>>,
56 absorbing_states: Option<Vec<usize>>,
57 max_iter: usize,
58 tol: f64,
59 ) -> Self {
60 let state_names =
61 state_names.unwrap_or_else(|| (0..n_states).map(|i| format!("State_{}", i)).collect());
62
63 let sojourn_distributions =
64 sojourn_distributions.unwrap_or_else(|| vec![SojournDistribution::Weibull; n_states]);
65
66 let absorbing_states = absorbing_states.unwrap_or_else(|| vec![n_states - 1]);
67
68 Self {
69 n_states,
70 state_names,
71 sojourn_distributions,
72 absorbing_states,
73 max_iter,
74 tol,
75 }
76 }
77}
78
/// Fitted sojourn-time distribution of a single state, together with summary
/// statistics. All fields are read-only from Python (`#[pyo3(get)]`).
///
/// The meaning of `shape`, `scale` and `location` depends on `distribution`,
/// as filled in by `fit_semi_markov`:
///
/// * Exponential: `shape = 1`, `scale = mean`, `location = 0`.
/// * Weibull / GeneralizedGamma: Weibull `shape` and `scale`, `location = 0`.
/// * LogNormal: `shape = sigma`, `scale = exp(mu)`, `location = mu`.
/// * Gamma: `shape`, `scale = 1 / rate`, `location = 0`.
#[derive(Debug, Clone)]
#[pyclass(from_py_object)]
pub struct SojournTimeParams {
    /// Distribution family these parameters belong to.
    #[pyo3(get)]
    pub distribution: SojournDistribution,
    /// Shape parameter (see the type-level table).
    #[pyo3(get)]
    pub shape: f64,
    /// Scale parameter (see the type-level table).
    #[pyo3(get)]
    pub scale: f64,
    /// Location parameter (see the type-level table).
    #[pyo3(get)]
    pub location: f64,
    /// Mean sojourn time implied by the fitted parameters.
    #[pyo3(get)]
    pub mean: f64,
    /// Variance of the sojourn time implied by the fitted parameters.
    #[pyo3(get)]
    pub variance: f64,
    /// Median sojourn time implied by the fitted parameters.
    #[pyo3(get)]
    pub median: f64,
}
97
98#[pymethods]
99impl SojournTimeParams {
100 fn __repr__(&self) -> String {
101 format!(
102 "SojournTimeParams(dist={:?}, mean={:.3}, var={:.3})",
103 self.distribution, self.mean, self.variance
104 )
105 }
106}
107
/// Output of `fit_semi_markov`. All fields are read-only from Python.
#[derive(Debug, Clone)]
#[pyclass(from_py_object)]
pub struct SemiMarkovResult {
    /// Empirical transition probabilities keyed by `"{from}_{to}"`.
    #[pyo3(get)]
    pub transition_probs: HashMap<String, f64>,
    /// Fitted sojourn-time parameters, one entry per state.
    #[pyo3(get)]
    pub sojourn_params: Vec<SojournTimeParams>,
    /// State occupation probabilities: one row per entry of `time_points`,
    /// one column per state.
    #[pyo3(get)]
    pub state_occupation_probs: Vec<Vec<f64>>,
    /// Time grid on which `state_occupation_probs` was evaluated.
    #[pyo3(get)]
    pub time_points: Vec<f64>,
    /// Mean sojourn time per state (copied from `sojourn_params`).
    #[pyo3(get)]
    pub mean_sojourn_times: Vec<f64>,
    /// Observed transition counts keyed by `"{from}_{to}"`.
    #[pyo3(get)]
    pub n_transitions: HashMap<String, usize>,
    /// Log-likelihood accumulated over the fitted sojourn distributions.
    #[pyo3(get)]
    pub log_likelihood: f64,
    /// Akaike information criterion.
    #[pyo3(get)]
    pub aic: f64,
    /// Bayesian information criterion.
    #[pyo3(get)]
    pub bic: f64,
}
130
131#[pymethods]
132impl SemiMarkovResult {
133 fn __repr__(&self) -> String {
134 format!(
135 "SemiMarkovResult(n_states={}, ll={:.2}, aic={:.2})",
136 self.sojourn_params.len(),
137 self.log_likelihood,
138 self.aic
139 )
140 }
141
142 fn get_transition_prob(&self, from_state: usize, to_state: usize) -> f64 {
143 let key = format!("{}_{}", from_state, to_state);
144 *self.transition_probs.get(&key).unwrap_or(&0.0)
145 }
146
147 fn predict_state_at_time(&self, time: f64) -> Vec<f64> {
148 if self.time_points.is_empty() {
149 return vec![0.0; self.sojourn_params.len()];
150 }
151
152 let idx = self
153 .time_points
154 .iter()
155 .position(|&t| t >= time)
156 .unwrap_or(self.time_points.len() - 1);
157
158 self.state_occupation_probs[idx].clone()
159 }
160}
161
/// Weibull probability density at `t` for the given `shape` and `scale`.
///
/// Yields `0.0` when `t`, `shape` or `scale` is zero or negative.
fn weibull_pdf(t: f64, shape: f64, scale: f64) -> f64 {
    let outside_domain = [t, shape, scale].iter().any(|&value| value <= 0.0);
    if outside_domain {
        return 0.0;
    }

    let ratio = t / scale;
    let leading = shape / scale;
    leading * ratio.powf(shape - 1.0) * (-ratio.powf(shape)).exp()
}
168
/// Weibull cumulative distribution function `1 - exp(-(t / scale)^shape)`.
///
/// Returns `0.0` for `t <= 0`. `shape` and `scale` are not validated here.
fn weibull_cdf(t: f64, shape: f64, scale: f64) -> f64 {
    if t <= 0.0 {
        return 0.0;
    }
    // `exp_m1` evaluates `exp(x) - 1` without the cancellation that
    // `1.0 - exp(-x)` suffers when the cumulative hazard is tiny.
    -(-(t / scale).powf(shape)).exp_m1()
}
175
/// Weibull survival function `exp(-(t / scale)^shape)`.
///
/// Evaluates to `1.0` for `t <= 0`.
fn weibull_survival(t: f64, shape: f64, scale: f64) -> f64 {
    if t <= 0.0 {
        1.0
    } else {
        let cumulative_hazard = (t / scale).powf(shape);
        (-cumulative_hazard).exp()
    }
}
182
/// Log-normal probability density at `t`, where `mu` and `sigma` are the
/// mean and standard deviation of `ln(t)`.
///
/// Yields `0.0` when `t` or `sigma` is zero or negative.
fn lognormal_pdf(t: f64, mu: f64, sigma: f64) -> f64 {
    use std::f64::consts::PI;

    if t <= 0.0 || sigma <= 0.0 {
        return 0.0;
    }

    let normaliser = 1.0 / (t * sigma * (2.0 * PI).sqrt());
    let standardised = (t.ln() - mu) / sigma;
    normaliser * (-0.5 * standardised.powi(2)).exp()
}
191
192fn lognormal_cdf(t: f64, mu: f64, sigma: f64) -> f64 {
193 if t <= 0.0 {
194 return 0.0;
195 }
196 0.5 * (1.0 + erf((t.ln() - mu) / (sigma * std::f64::consts::SQRT_2)))
197}
198
199fn gamma_pdf(t: f64, shape: f64, rate: f64) -> f64 {
200 if t <= 0.0 || shape <= 0.0 || rate <= 0.0 {
201 return 0.0;
202 }
203 let ln_gamma = ln_gamma_fn(shape);
204 (shape * rate.ln() + (shape - 1.0) * t.ln() - rate * t - ln_gamma).exp()
205}
206
/// Natural logarithm of the gamma function, evaluated with a six-term
/// Lanczos-type series.
///
/// NOTE(review): the formula takes `ln` of a quantity divided by `x`, so it
/// is presumably only meaningful for `x > 0` — confirm against callers.
fn ln_gamma_fn(x: f64) -> f64 {
    const SERIES_COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    const SERIES_START: f64 = 1.000000000190015;
    const SQRT_TWO_PI: f64 = 2.5066282746310005;

    let shifted = x + 5.5;
    let offset = shifted - (x + 0.5) * shifted.ln();

    // Accumulate coeff_k / (x + 1 + k) on top of the series' leading constant.
    let series = SERIES_COEFFS
        .iter()
        .enumerate()
        .fold(SERIES_START, |acc, (k, &coeff)| acc + coeff / (x + 1.0 + k as f64));

    -offset + (SQRT_TWO_PI * series / x).ln()
}
225
/// Maximum-likelihood fit of a two-parameter Weibull distribution.
///
/// Returns `(shape, scale)`. The shape `k` solves the profile score equation
///
/// ```text
/// sum(t^k ln t) / sum(t^k) - 1/k - mean(ln t) = 0
/// ```
///
/// by Newton-Raphson started at `k = 1`; the scale then follows in closed
/// form as `(mean(t^k))^(1/k)`.
///
/// * An empty sample yields `(1.0, 1.0)`.
/// * Observations are floored at `1e-10` before taking logarithms, matching
///   the other fitters in this file.
/// * The shape is confined to `[0.1, 100]`. The upper bound is only reached
///   for (near-)degenerate samples such as identical observations, for which
///   the likelihood has no finite maximiser.
fn fit_weibull_mle(times: &[f64]) -> (f64, f64) {
    const MIN_TIME: f64 = 1e-10;
    const MIN_SHAPE: f64 = 0.1;
    const MAX_SHAPE: f64 = 100.0;
    const MAX_ITER: usize = 50;
    const STEP_TOL: f64 = 1e-6;

    if times.is_empty() {
        return (1.0, 1.0);
    }

    let n = times.len() as f64;

    // Work with ratios r = t / t_max <= 1. The score equation only involves
    // ratios of power sums, so it is unchanged, but r^k can no longer
    // overflow for large shapes.
    let t_max = times.iter().fold(MIN_TIME, |acc, &t| acc.max(t));
    let samples: Vec<(f64, f64)> = times
        .iter()
        .map(|&t| {
            let t = t.max(MIN_TIME);
            (t / t_max, t.ln())
        })
        .collect();
    let mean_ln_t = samples.iter().map(|&(_, ln_t)| ln_t).sum::<f64>() / n;

    let mut shape = 1.0_f64;
    for _ in 0..MAX_ITER {
        // Weighted power sums: s0 = sum r^k, s1 = sum r^k ln t, s2 = sum r^k ln^2 t.
        let (s0, s1, s2) = samples.iter().fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(s0, s1, s2), &(ratio, ln_t)| {
                let weight = ratio.powf(shape);
                (s0 + weight, s1 + weight * ln_t, s2 + weight * ln_t * ln_t)
            },
        );

        let weighted_mean = s1 / s0;
        let score = weighted_mean - 1.0 / shape - mean_ln_t;
        // d(score)/dk = weighted variance of ln t + 1/k^2, which is strictly
        // positive; the `max` only guards against rounding noise.
        let score_slope =
            (s2 / s0 - weighted_mean * weighted_mean).max(0.0) + 1.0 / (shape * shape);

        let next = (shape - score / score_slope).clamp(MIN_SHAPE, MAX_SHAPE);
        if !next.is_finite() {
            // Non-finite input poisoned the sums; keep the last usable shape.
            break;
        }

        let converged = (next - shape).abs() < STEP_TOL;
        shape = next;
        if converged {
            break;
        }
    }

    let mean_ratio_pow = samples.iter().map(|&(ratio, _)| ratio.powf(shape)).sum::<f64>() / n;
    let scale = t_max * mean_ratio_pow.powf(1.0 / shape);

    (shape, scale.max(1e-10))
}
258
/// Maximum-likelihood fit of a log-normal distribution.
///
/// Returns `(mu, sigma)`: the mean and the (population, divide-by-`n`)
/// standard deviation of the log-transformed observations.
///
/// * An empty sample yields `(0.0, 1.0)`.
/// * Observations are floored at `1e-10` before taking logarithms.
/// * `sigma` is floored at `0.01`.
fn fit_lognormal_mle(times: &[f64]) -> (f64, f64) {
    if times.is_empty() {
        return (0.0, 1.0);
    }

    let log_times: Vec<f64> = times.iter().map(|&t| t.max(1e-10).ln()).collect();
    let n = log_times.len() as f64;
    let mu = log_times.iter().sum::<f64>() / n;
    let sigma = (log_times.iter().map(|&lt| (lt - mu).powi(2)).sum::<f64>() / n)
        .sqrt()
        .max(0.01);

    (mu, sigma)
}
273
/// Approximate maximum-likelihood fit of a gamma distribution.
///
/// Returns `(shape, rate)`. With `s = ln(mean(t)) - mean(ln t)`, the shape
/// uses the closed-form approximation
/// `(3 - s + sqrt((s - 3)^2 + 24 s)) / (12 s)`, and the rate is then
/// `shape / mean(t)` so that the fitted mean reproduces the sample mean.
///
/// * An empty sample yields `(1.0, 1.0)`.
/// * Observations are floored at `1e-10` before taking logarithms.
/// * When `s <= 0` (e.g. identical observations) the shape falls back to `1`.
/// * The shape is floored at `0.1`. The rate is derived from the floored
///   shape and only guarded against non-positive values (`1e-10`), so
///   samples on large time scales are not distorted.
fn fit_gamma_mle(times: &[f64]) -> (f64, f64) {
    const MIN_TIME: f64 = 1e-10;
    const MIN_SHAPE: f64 = 0.1;
    const MIN_RATE: f64 = 1e-10;

    if times.is_empty() {
        return (1.0, 1.0);
    }

    let n = times.len() as f64;
    let mean = times.iter().sum::<f64>() / n;
    let log_mean = times.iter().map(|&t| t.max(MIN_TIME).ln()).sum::<f64>() / n;

    let s = mean.ln() - log_mean;
    let shape = if s > 0.0 {
        (3.0 - s + ((s - 3.0).powi(2) + 24.0 * s).sqrt()) / (12.0 * s)
    } else {
        1.0
    };
    let shape = shape.max(MIN_SHAPE);

    // Derive the rate from the shape that is actually returned so that
    // `shape / rate` equals the sample mean.
    let rate = shape / mean;

    (shape, rate.max(MIN_RATE))
}
294
295#[pyfunction]
296#[pyo3(signature = (entry_times, exit_times, from_states, to_states, config))]
297pub fn fit_semi_markov(
298 entry_times: Vec<f64>,
299 exit_times: Vec<f64>,
300 from_states: Vec<i32>,
301 to_states: Vec<i32>,
302 config: &SemiMarkovConfig,
303) -> PyResult<SemiMarkovResult> {
304 let n = entry_times.len();
305 if exit_times.len() != n || from_states.len() != n || to_states.len() != n {
306 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
307 "All input vectors must have the same length",
308 ));
309 }
310
311 let sojourn_times: Vec<f64> = entry_times
312 .iter()
313 .zip(exit_times.iter())
314 .map(|(&entry, &exit)| (exit - entry).max(1e-10))
315 .collect();
316
317 let mut transition_counts: HashMap<String, usize> = HashMap::new();
318 let mut state_counts: Vec<usize> = vec![0; config.n_states];
319
320 for i in 0..n {
321 let from = from_states[i] as usize;
322 let to = to_states[i] as usize;
323 if from < config.n_states && to < config.n_states {
324 let key = format!("{}_{}", from, to);
325 *transition_counts.entry(key).or_insert(0) += 1;
326 state_counts[from] += 1;
327 }
328 }
329
330 let mut transition_probs: HashMap<String, f64> = HashMap::new();
331 for (from, &state_count) in state_counts.iter().enumerate().take(config.n_states) {
332 for to in 0..config.n_states {
333 let key = format!("{}_{}", from, to);
334 let count = *transition_counts.get(&key).unwrap_or(&0);
335 let prob = if state_count > 0 {
336 count as f64 / state_count as f64
337 } else {
338 0.0
339 };
340 transition_probs.insert(key, prob);
341 }
342 }
343
344 let mut sojourn_params: Vec<SojournTimeParams> = Vec::new();
345 let mut log_likelihood = 0.0;
346 let mut n_params = 0;
347
348 for state in 0..config.n_states {
349 let state_sojourn: Vec<f64> = (0..n)
350 .filter(|&i| from_states[i] as usize == state)
351 .map(|i| sojourn_times[i])
352 .collect();
353
354 let dist = config.sojourn_distributions[state];
355 let (shape, scale, location, mean, variance, median) = if state_sojourn.is_empty() {
356 (1.0, 1.0, 0.0, 1.0, 1.0, 1.0)
357 } else {
358 match dist {
359 SojournDistribution::Exponential => {
360 let mean = state_sojourn.iter().sum::<f64>() / state_sojourn.len() as f64;
361 let rate = 1.0 / mean;
362 for &t in &state_sojourn {
363 log_likelihood += rate.ln() - rate * t;
364 }
365 n_params += 1;
366 (1.0, mean, 0.0, mean, mean.powi(2), mean * 2.0_f64.ln())
367 }
368 SojournDistribution::Weibull => {
369 let (shape, scale) = fit_weibull_mle(&state_sojourn);
370 for &t in &state_sojourn {
371 let pdf = weibull_pdf(t, shape, scale);
372 if pdf > 1e-300 {
373 log_likelihood += pdf.ln();
374 }
375 }
376 n_params += 2;
377 let mean = scale * ln_gamma_fn(1.0 + 1.0 / shape).exp();
378 let var = scale.powi(2)
379 * (ln_gamma_fn(1.0 + 2.0 / shape).exp()
380 - ln_gamma_fn(1.0 + 1.0 / shape).exp().powi(2));
381 let median = scale * 2.0_f64.ln().powf(1.0 / shape);
382 (shape, scale, 0.0, mean, var, median)
383 }
384 SojournDistribution::LogNormal => {
385 let (mu, sigma) = fit_lognormal_mle(&state_sojourn);
386 for &t in &state_sojourn {
387 let pdf = lognormal_pdf(t, mu, sigma);
388 if pdf > 1e-300 {
389 log_likelihood += pdf.ln();
390 }
391 }
392 n_params += 2;
393 let mean = (mu + sigma.powi(2) / 2.0).exp();
394 let var = (sigma.powi(2).exp() - 1.0) * (2.0 * mu + sigma.powi(2)).exp();
395 let median = mu.exp();
396 (sigma, mu.exp(), mu, mean, var, median)
397 }
398 SojournDistribution::Gamma => {
399 let (shape, rate) = fit_gamma_mle(&state_sojourn);
400 for &t in &state_sojourn {
401 let pdf = gamma_pdf(t, shape, rate);
402 if pdf > 1e-300 {
403 log_likelihood += pdf.ln();
404 }
405 }
406 n_params += 2;
407 let mean = shape / rate;
408 let var = shape / rate.powi(2);
409 let median = mean * (1.0 - 1.0 / (9.0 * shape)).powi(3);
410 (shape, 1.0 / rate, 0.0, mean, var, median)
411 }
412 SojournDistribution::GeneralizedGamma => {
413 let (shape, scale) = fit_weibull_mle(&state_sojourn);
414 for &t in &state_sojourn {
415 let pdf = weibull_pdf(t, shape, scale);
416 if pdf > 1e-300 {
417 log_likelihood += pdf.ln();
418 }
419 }
420 n_params += 3;
421 let mean = scale * ln_gamma_fn(1.0 + 1.0 / shape).exp();
422 let var = scale.powi(2)
423 * (ln_gamma_fn(1.0 + 2.0 / shape).exp()
424 - ln_gamma_fn(1.0 + 1.0 / shape).exp().powi(2));
425 let median = scale * 2.0_f64.ln().powf(1.0 / shape);
426 (shape, scale, 0.0, mean, var, median)
427 }
428 }
429 };
430
431 sojourn_params.push(SojournTimeParams {
432 distribution: dist,
433 shape,
434 scale,
435 location,
436 mean,
437 variance,
438 median,
439 });
440 }
441
442 let max_time = exit_times.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
443 let n_time_points = 100;
444 let time_points: Vec<f64> = (0..=n_time_points)
445 .map(|i| i as f64 * max_time / n_time_points as f64)
446 .collect();
447
448 let mut state_occupation_probs: Vec<Vec<f64>> = Vec::new();
449 for &t in &time_points {
450 let mut probs = vec![0.0; config.n_states];
451 if t == 0.0 {
452 probs[0] = 1.0;
453 } else {
454 for state in 0..config.n_states {
455 if config.absorbing_states.contains(&state) {
456 let mut absorb_prob = 0.0;
457 for (from, params) in sojourn_params.iter().enumerate().take(config.n_states) {
458 if !config.absorbing_states.contains(&from) {
459 let key = format!("{}_{}", from, state);
460 let trans_prob = *transition_probs.get(&key).unwrap_or(&0.0);
461 let cdf = match params.distribution {
462 SojournDistribution::Weibull
463 | SojournDistribution::GeneralizedGamma => {
464 weibull_cdf(t, params.shape, params.scale)
465 }
466 SojournDistribution::Exponential => 1.0 - (-t / params.scale).exp(),
467 SojournDistribution::LogNormal => {
468 lognormal_cdf(t, params.location, params.shape)
469 }
470 SojournDistribution::Gamma => {
471 1.0 - (-t * params.shape / params.scale).exp()
472 }
473 };
474 absorb_prob += trans_prob * cdf;
475 }
476 }
477 probs[state] = absorb_prob.min(1.0);
478 } else {
479 let params = &sojourn_params[state];
480 let surv = match params.distribution {
481 SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
482 weibull_survival(t, params.shape, params.scale)
483 }
484 SojournDistribution::Exponential => (-t / params.scale).exp(),
485 SojournDistribution::LogNormal => {
486 1.0 - lognormal_cdf(t, params.location, params.shape)
487 }
488 SojournDistribution::Gamma => (-t * params.shape / params.scale).exp(),
489 };
490 probs[state] = surv * (1.0 - probs.iter().sum::<f64>()).max(0.0);
491 }
492 }
493 }
494
495 let sum: f64 = probs.iter().sum();
496 if sum > 0.0 {
497 for p in &mut probs {
498 *p /= sum;
499 }
500 }
501 state_occupation_probs.push(probs);
502 }
503
504 let mean_sojourn_times: Vec<f64> = sojourn_params.iter().map(|p| p.mean).collect();
505
506 let n_obs = n as f64;
507 let aic = -2.0 * log_likelihood + 2.0 * n_params as f64;
508 let bic = -2.0 * log_likelihood + (n_params as f64) * n_obs.ln();
509
510 Ok(SemiMarkovResult {
511 transition_probs,
512 sojourn_params,
513 state_occupation_probs,
514 time_points,
515 mean_sojourn_times,
516 n_transitions: transition_counts,
517 log_likelihood,
518 aic,
519 bic,
520 })
521}
522
/// Output of `predict_semi_markov`. All fields are read-only from Python.
#[derive(Debug, Clone)]
#[pyclass(from_py_object)]
pub struct SemiMarkovPrediction {
    /// Predicted state probabilities: one row per entry of `time_points`,
    /// one column per state.
    #[pyo3(get)]
    pub state_probs: Vec<Vec<f64>>,
    /// The prediction horizons that were passed in, unchanged.
    #[pyo3(get)]
    pub time_points: Vec<f64>,
    /// Mean sojourn time of every state, taken from the fitted model.
    #[pyo3(get)]
    pub expected_sojourn: Vec<f64>,
    /// Hazard per prediction horizon for each transition out of the current
    /// state, keyed by `"{current_state}_{to_state}"`.
    #[pyo3(get)]
    pub transition_hazards: HashMap<String, Vec<f64>>,
}
535
536#[pyfunction]
537#[pyo3(signature = (model, current_state, time_in_state, prediction_times))]
538pub fn predict_semi_markov(
539 model: &SemiMarkovResult,
540 current_state: usize,
541 time_in_state: f64,
542 prediction_times: Vec<f64>,
543) -> PyResult<SemiMarkovPrediction> {
544 let n_states = model.sojourn_params.len();
545 if current_state >= n_states {
546 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
547 "current_state must be less than number of states",
548 ));
549 }
550
551 let params = &model.sojourn_params[current_state];
552 let current_survival = match params.distribution {
553 SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
554 weibull_survival(time_in_state, params.shape, params.scale)
555 }
556 SojournDistribution::Exponential => (-time_in_state / params.scale).exp(),
557 _ => 1.0,
558 };
559
560 let mut state_probs: Vec<Vec<f64>> = Vec::new();
561 let mut transition_hazards: HashMap<String, Vec<f64>> = HashMap::new();
562
563 for to_state in 0..n_states {
564 let key = format!("{}_{}", current_state, to_state);
565 transition_hazards.insert(key.clone(), Vec::new());
566 }
567
568 for &t in &prediction_times {
569 let total_time = time_in_state + t;
570 let mut probs = vec![0.0; n_states];
571
572 let future_survival = match params.distribution {
573 SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
574 weibull_survival(total_time, params.shape, params.scale)
575 }
576 SojournDistribution::Exponential => (-total_time / params.scale).exp(),
577 _ => 1.0,
578 };
579
580 let conditional_survival = if current_survival > 1e-10 {
581 future_survival / current_survival
582 } else {
583 0.0
584 };
585
586 probs[current_state] = conditional_survival;
587
588 let exit_prob = 1.0 - conditional_survival;
589 for (to_state, prob) in probs.iter_mut().enumerate().take(n_states) {
590 if to_state != current_state {
591 let key = format!("{}_{}", current_state, to_state);
592 let trans_prob = *model.transition_probs.get(&key).unwrap_or(&0.0);
593 *prob = exit_prob * trans_prob;
594 }
595 }
596
597 state_probs.push(probs);
598
599 for to_state in 0..n_states {
600 let key = format!("{}_{}", current_state, to_state);
601 let trans_prob = *model.transition_probs.get(&key).unwrap_or(&0.0);
602
603 let hazard = if conditional_survival > 1e-10 {
604 let pdf = match params.distribution {
605 SojournDistribution::Weibull | SojournDistribution::GeneralizedGamma => {
606 weibull_pdf(total_time, params.shape, params.scale)
607 }
608 SojournDistribution::Exponential => {
609 (1.0 / params.scale) * (-total_time / params.scale).exp()
610 }
611 _ => 0.0,
612 };
613 trans_prob * pdf / future_survival
614 } else {
615 0.0
616 };
617
618 if let Some(hazards) = transition_hazards.get_mut(&key) {
619 hazards.push(hazard);
620 } else {
621 return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
622 "internal error: missing transition hazard bucket",
623 ));
624 }
625 }
626 }
627
628 let expected_sojourn: Vec<f64> = model.sojourn_params.iter().map(|p| p.mean).collect();
629
630 Ok(SemiMarkovPrediction {
631 state_probs,
632 time_points: prediction_times,
633 expected_sojourn,
634 transition_hazards,
635 })
636}
637
#[cfg(test)]
mod tests {
    use super::*;

    // Default construction sizes the per-state collections to `n_states`.
    #[test]
    fn test_semi_markov_config() {
        let config = SemiMarkovConfig::new(3, None, None, None, 100, 1e-6);
        assert_eq!(config.n_states, 3);
        assert_eq!(config.state_names.len(), 3);
        assert_eq!(config.sojourn_distributions.len(), 3);
    }

    // The Weibull density, CDF and survival function are mutually consistent
    // at an interior point: values lie in (0, 1) and CDF + survival == 1.
    #[test]
    fn test_weibull_functions() {
        let pdf = weibull_pdf(1.0, 2.0, 1.0);
        assert!(pdf > 0.0 && pdf < 1.0);

        let cdf = weibull_cdf(1.0, 2.0, 1.0);
        assert!(cdf > 0.0 && cdf < 1.0);

        let surv = weibull_survival(1.0, 2.0, 1.0);
        assert!((surv + cdf - 1.0).abs() < 1e-10);
    }

    // End-to-end fit on a small 0 -> 1 -> 2 progression with state 2
    // absorbing; checks the result shape and that the likelihood is finite.
    #[test]
    fn test_fit_semi_markov() {
        let entry_times = vec![0.0, 1.0, 2.0, 3.0, 0.0, 1.5, 2.5, 3.5];
        let exit_times = vec![1.0, 2.0, 3.0, 4.0, 1.5, 2.5, 3.5, 5.0];
        let from_states = vec![0, 0, 1, 1, 0, 0, 1, 1];
        let to_states = vec![1, 1, 2, 2, 1, 1, 2, 2];

        let config = SemiMarkovConfig::new(3, None, None, Some(vec![2]), 100, 1e-6);
        let result =
            fit_semi_markov(entry_times, exit_times, from_states, to_states, &config).unwrap();

        assert_eq!(result.sojourn_params.len(), 3);
        assert!(!result.transition_probs.is_empty());
        assert!(result.log_likelihood.is_finite());
    }

    // Prediction returns one probability row per requested horizon and
    // echoes the horizons back.
    #[test]
    fn test_predict_semi_markov() {
        let entry_times = vec![0.0, 1.0, 2.0, 3.0];
        let exit_times = vec![1.0, 2.0, 3.0, 4.0];
        let from_states = vec![0, 0, 1, 1];
        let to_states = vec![1, 1, 2, 2];

        let config = SemiMarkovConfig::new(3, None, None, Some(vec![2]), 100, 1e-6);
        let model =
            fit_semi_markov(entry_times, exit_times, from_states, to_states, &config).unwrap();

        let prediction = predict_semi_markov(&model, 0, 0.5, vec![0.5, 1.0, 1.5, 2.0]).unwrap();

        assert_eq!(prediction.state_probs.len(), 4);
        assert_eq!(prediction.time_points.len(), 4);
    }
}