winsfs_core/
em.rs

1//! Expectation-maximisation ("EM") algorithms for SFS inference.
2
3use std::io;
4
5pub mod likelihood;
6
7mod adaptors;
8pub use adaptors::Inspect;
9
10mod site;
11pub use site::{EmSite, StreamEmSite};
12
13mod standard_em;
14pub use standard_em::{ParallelStandardEm, StandardEm};
15
16pub mod stopping;
17use stopping::Stop;
18
19mod window_em;
20pub use window_em::{Window, WindowEm};
21
22use crate::{
23    io::Rewind,
24    sfs::{Sfs, USfs},
25};
26
27/// An EM-like type that runs in steps.
28///
29/// This serves as a supertrait bound for both [`Em`] and [`StreamingEm`] and gathers
30/// behaviour shared around running consecutive EM-steps.
31pub trait EmStep: Sized {
32    /// The status returned after each step.
33    ///
34    /// This may be used, for example, to determine convergence by the stopping rule,
35    /// or can be logged using [`EmStep::inspect`]. An example of a status might
36    /// be the log-likelihood of the data given the SFS after the E-step.
37    type Status;
38
39    /// Inspect the status after each E-step.
40    fn inspect<const N: usize, F>(self, f: F) -> Inspect<Self, F>
41    where
42        F: FnMut(&Self, &Self::Status, &USfs<N>),
43    {
44        Inspect::new(self, f)
45    }
46}
47
48/// A type capable of running an EM-like algorithm for SFS inference using data in-memory.
49pub trait Em<const N: usize, I>: EmStep {
50    /// The E-step of the algorithm.
51    ///
52    /// This should correspond to a full pass over the `input`.
53    ///
54    /// # Panics
55    ///
56    /// Panics if the shapes of the SFS and the input do not match.
57    fn e_step(&mut self, sfs: Sfs<N>, input: &I) -> (Self::Status, USfs<N>);
58
59    /// A full EM-step of the algorithm.
60    ///
61    /// Like the [`Em::e_step`], this should correspond to a full pass over the `input`.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the shapes of the SFS and the input do not match.
66    fn em_step(&mut self, sfs: Sfs<N>, input: &I) -> (Self::Status, Sfs<N>) {
67        let (status, posterior) = self.e_step(sfs, input);
68
69        (status, posterior.normalise())
70    }
71
72    /// Runs the EM algorithm until convergence.
73    ///
74    /// This consists of running EM-steps until convergence, which is decided by the provided
75    /// `stopping_rule`. The converged SFS, and the status of the last EM-step, are returned.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the shapes of the SFS and the input do not match.
80    fn em<S>(&mut self, mut sfs: Sfs<N>, input: &I, mut stopping_rule: S) -> (Self::Status, Sfs<N>)
81    where
82        S: Stop<Self, Status = Self::Status>,
83    {
84        loop {
85            let (status, new_sfs) = self.em_step(sfs, input);
86            sfs = new_sfs;
87
88            if stopping_rule.stop(self, &status, &sfs) {
89                break (status, sfs);
90            }
91        }
92    }
93}
94
95/// A type capable of running an EM-like algorithm for SFS inference by streaming through data.
96pub trait StreamingEm<const D: usize, R>: EmStep
97where
98    R: Rewind,
99    R::Site: EmSite<D>,
100{
101    /// The E-step of the algorithm.
102    ///
103    /// This should correspond to a full pass through the `reader`.
104    ///
105    /// # Panics
106    ///
107    /// Panics if the shapes of the SFS and the input do not match.
108    fn stream_e_step(&mut self, sfs: Sfs<D>, reader: &mut R)
109        -> io::Result<(Self::Status, USfs<D>)>;
110
111    /// A full EM-step of the algorithm.
112    ///
113    /// Like the [`Em::e_step`], this should correspond to a full pass through the `reader`.
114    ///
115    /// # Panics
116    ///
117    /// Panics if the shapes of the SFS and the input do not match.
118    fn stream_em_step(
119        &mut self,
120        sfs: Sfs<D>,
121        reader: &mut R,
122    ) -> io::Result<(Self::Status, Sfs<D>)> {
123        let (status, posterior) = self.stream_e_step(sfs, reader)?;
124
125        Ok((status, posterior.normalise()))
126    }
127
128    /// Runs the EM algorithm until convergence.
129    ///
130    /// This consists of running EM-steps until convergence, which is decided by the provided
131    /// `stopping_rule`. The converged SFS, and the status of the last EM-step, are returned.
132    ///
133    /// # Panics
134    ///
135    /// Panics if the shapes of the SFS and the input do not match.
136    fn stream_em<S>(
137        &mut self,
138        mut sfs: Sfs<D>,
139        reader: &mut R,
140        mut stopping_rule: S,
141    ) -> io::Result<(Self::Status, Sfs<D>)>
142    where
143        S: Stop<Self, Status = Self::Status>,
144    {
145        loop {
146            let (status, new_sfs) = self.stream_em_step(sfs, reader)?;
147            sfs = new_sfs;
148
149            if stopping_rule.stop(self, &status, &sfs) {
150                break Ok((status, sfs));
151            } else {
152                reader.rewind()?;
153            }
154        }
155    }
156}
157
/// Converts `x` into an `f64`, panicking if the conversion is not exact.
///
/// # Panics
///
/// Panics if `x` cannot be represented exactly as an `f64`, for instance odd values above
/// 2^53, or `usize::MAX` on 64-bit targets.
fn to_f64(x: usize) -> f64 {
    let result = x as f64;
    // Compare the round-trip in `u128` rather than `usize`. Large values such as `usize::MAX`
    // round up to 2^64 as an `f64`, and casting 2^64 back to `usize` saturates to `usize::MAX`,
    // which would make that inexact conversion look exact. `u128` holds every `usize` value,
    // and 2^64, exactly, so this comparison cannot be fooled by saturation.
    if result as u128 != x as u128 {
        panic!("cannot convert {x} (usize) into f64");
    }
    result
}
165
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{saf::Saf, saf1d, sfs1d};

    /// Runs a single EM-step starting from an SFS with zero entries, and checks that the
    /// resulting SFS contains no NaN values.
    fn impl_test_em_zero_not_nan<E>(mut em_runner: E)
    where
        E: Em<1, Saf<1>>,
    {
        let input = saf1d![[0., 0., 1.]];
        let start = sfs1d![1., 0., 0.].into_normalised().unwrap();
        let single_step = stopping::Steps::new(1);

        let (_, estimate) = em_runner.em(start, &input, single_step);

        assert!(estimate.iter().all(|value| !value.is_nan()));
    }

    #[test]
    fn test_em_zero_sfs_not_nan() {
        impl_test_em_zero_not_nan(StandardEm::<false>::new());
    }

    #[test]
    fn test_parallel_em_zero_sfs_not_nan() {
        impl_test_em_zero_not_nan(ParallelStandardEm::new());
    }

    #[test]
    fn test_window_em_zero_sfs_not_nan() {
        let inner = StandardEm::<false>::new();
        let window = Window::from_zeros([3], 1);
        impl_test_em_zero_not_nan(WindowEm::new(inner, window, 1));
    }
}