// winsfs_core/em/standard_em.rs

1use std::io;
2
3use crate::{
4    io::Rewind,
5    saf::iter::{IntoParallelSiteIterator, IntoSiteIterator},
6    sfs::{Sfs, USfs},
7};
8
9use super::{
10    likelihood::{LogLikelihood, SumOf},
11    Em, EmStep, StreamEmSite, StreamingEm,
12};
13
14/// A parallel runner of the standard EM algorithm.
15pub type ParallelStandardEm = StandardEm<true>;
16
/// A runner of the standard EM algorithm.
///
/// Whether to parallelise over the input in the E-step is controlled by the `PAR` parameter:
/// `StandardEm<false>` (the default) uses the sequential E-step, while `StandardEm<true>`
/// (also available as `ParallelStandardEm`) uses the parallel one.
///
/// The runner holds no data, so it is zero-sized. It is therefore `Copy` and `Hash` in
/// addition to the usual comparison and debug traits.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
// TODO: Use an enum here when stable, see github.com/rust-lang/rust/issues/95174
pub struct StandardEm<const PAR: bool = false> {
    // Private unit field: code outside this module cannot build the struct with a literal,
    // so it must go through `StandardEm::new` or `Default`.
    _private: (),
}
26
27impl<const PAR: bool> StandardEm<PAR> {
28    /// Returns a new instance of the runner.
29    pub fn new() -> Self {
30        Self { _private: () }
31    }
32}
33
34impl<const PAR: bool> Default for StandardEm<PAR> {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl<const PAR: bool> EmStep for StandardEm<PAR> {
41    type Status = SumOf<LogLikelihood>;
42}
43
44impl<const D: usize, I> Em<D, I> for StandardEm<false>
45where
46    for<'a> &'a I: IntoSiteIterator<D>,
47{
48    fn e_step(&mut self, sfs: Sfs<D>, input: &I) -> (Self::Status, USfs<D>) {
49        sfs.e_step(input)
50    }
51}
52
53impl<const D: usize, R> StreamingEm<D, R> for StandardEm<false>
54where
55    R: Rewind,
56    R::Site: StreamEmSite<D>,
57{
58    fn stream_e_step(
59        &mut self,
60        sfs: Sfs<D>,
61        reader: &mut R,
62    ) -> io::Result<(Self::Status, USfs<D>)> {
63        sfs.stream_e_step(reader)
64    }
65}
66
67impl<const D: usize, I> Em<D, I> for StandardEm<true>
68where
69    for<'a> &'a I: IntoParallelSiteIterator<D>,
70{
71    fn e_step(&mut self, sfs: Sfs<D>, input: &I) -> (Self::Status, USfs<D>) {
72        sfs.par_e_step(input)
73    }
74}