1use std::io;
4
5pub mod likelihood;
6
7mod adaptors;
8pub use adaptors::Inspect;
9
10mod site;
11pub use site::{EmSite, StreamEmSite};
12
13mod standard_em;
14pub use standard_em::{ParallelStandardEm, StandardEm};
15
16pub mod stopping;
17use stopping::Stop;
18
19mod window_em;
20pub use window_em::{Window, WindowEm};
21
22use crate::{
23 io::Rewind,
24 sfs::{Sfs, USfs},
25};
26
/// Base trait shared by the EM traits in this module.
///
/// It declares the [`Status`](EmStep::Status) associated type and provides the
/// [`inspect`](EmStep::inspect) adaptor. Both [`Em`] and [`StreamingEm`] have this trait as a
/// supertrait.
pub trait EmStep: Sized {
    /// The status value returned alongside the SFS by each step, see e.g. [`Em::e_step`] and
    /// [`StreamingEm::stream_e_step`].
    type Status;

    /// Consumes `self` and wraps it, together with the closure `f`, in an [`Inspect`] adaptor.
    ///
    /// The closure takes a reference to the runner, a reference to the step status, and a
    /// reference to a `USfs<N>`.
    ///
    /// NOTE(review): when the closure is invoked is decided by [`Inspect`], not by this method;
    /// presumably once per E-step. Confirm against the `adaptors` module.
    fn inspect<const N: usize, F>(self, f: F) -> Inspect<Self, F>
    where
        F: FnMut(&Self, &Self::Status, &USfs<N>),
    {
        Inspect::new(self, f)
    }
}
47
48pub trait Em<const N: usize, I>: EmStep {
50 fn e_step(&mut self, sfs: Sfs<N>, input: &I) -> (Self::Status, USfs<N>);
58
59 fn em_step(&mut self, sfs: Sfs<N>, input: &I) -> (Self::Status, Sfs<N>) {
67 let (status, posterior) = self.e_step(sfs, input);
68
69 (status, posterior.normalise())
70 }
71
72 fn em<S>(&mut self, mut sfs: Sfs<N>, input: &I, mut stopping_rule: S) -> (Self::Status, Sfs<N>)
81 where
82 S: Stop<Self, Status = Self::Status>,
83 {
84 loop {
85 let (status, new_sfs) = self.em_step(sfs, input);
86 sfs = new_sfs;
87
88 if stopping_rule.stop(self, &status, &sfs) {
89 break (status, sfs);
90 }
91 }
92 }
93}
94
95pub trait StreamingEm<const D: usize, R>: EmStep
97where
98 R: Rewind,
99 R::Site: EmSite<D>,
100{
101 fn stream_e_step(&mut self, sfs: Sfs<D>, reader: &mut R)
109 -> io::Result<(Self::Status, USfs<D>)>;
110
111 fn stream_em_step(
119 &mut self,
120 sfs: Sfs<D>,
121 reader: &mut R,
122 ) -> io::Result<(Self::Status, Sfs<D>)> {
123 let (status, posterior) = self.stream_e_step(sfs, reader)?;
124
125 Ok((status, posterior.normalise()))
126 }
127
128 fn stream_em<S>(
137 &mut self,
138 mut sfs: Sfs<D>,
139 reader: &mut R,
140 mut stopping_rule: S,
141 ) -> io::Result<(Self::Status, Sfs<D>)>
142 where
143 S: Stop<Self, Status = Self::Status>,
144 {
145 loop {
146 let (status, new_sfs) = self.stream_em_step(sfs, reader)?;
147 sfs = new_sfs;
148
149 if stopping_rule.stop(self, &status, &sfs) {
150 break Ok((status, sfs));
151 } else {
152 reader.rewind()?;
153 }
154 }
155 }
156}
157
/// Converts a `usize` into an `f64`, insisting that the conversion is exact.
///
/// Returns `x` as an `f64` holding exactly the same integer value.
///
/// # Panics
///
/// Panics if `x` cannot be represented exactly as an `f64`, i.e. if the cast had to round.
pub(self) fn to_f64(x: usize) -> f64 {
    let result = x as f64;
    // Float-to-integer `as` casts saturate. On 64-bit targets `usize::MAX` rounds up to 2^64 as
    // an `f64`, and casting that back to `usize` saturates to `usize::MAX` again, so a round trip
    // through `usize` would wrongly accept it. 2^64 fits in a `u128`, so comparing there is exact.
    if result as u128 != x as u128 {
        panic!("cannot convert {x} (usize) into f64");
    }
    result
}
165
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{saf::Saf, saf1d, sfs1d};

    /// Runs a single EM step with `runner`, starting from an SFS whose only non-zero entry is the
    /// first one and a single site whose only non-zero entry is the last one, and checks that no
    /// entry of the resulting SFS is NaN.
    fn impl_test_em_zero_not_nan<E>(mut runner: E)
    where
        E: Em<1, Saf<1>>,
    {
        let initial = sfs1d![1., 0., 0.].into_normalised().unwrap();
        let input = saf1d![[0., 0., 1.]];
        let one_step = stopping::Steps::new(1);

        let (_, sfs) = runner.em(initial, &input, one_step);

        let has_nan = sfs.iter().any(|x| x.is_nan());
        assert!(!has_nan);
    }

    /// The sequential standard EM must not produce NaN.
    #[test]
    fn test_em_zero_sfs_not_nan() {
        let runner = StandardEm::<false>::new();

        impl_test_em_zero_not_nan(runner);
    }

    /// The parallel standard EM must not produce NaN.
    #[test]
    fn test_parallel_em_zero_sfs_not_nan() {
        let runner = ParallelStandardEm::new();

        impl_test_em_zero_not_nan(runner);
    }

    /// The window EM wrapping a sequential standard EM must not produce NaN.
    #[test]
    fn test_window_em_zero_sfs_not_nan() {
        let inner = StandardEm::<false>::new();
        let window = Window::from_zeros([3], 1);
        let runner = WindowEm::new(inner, window, 1);

        impl_test_em_zero_not_nan(runner);
    }
}