winsfs_core/sfs/em.rs
1use std::io;
2
3use rayon::iter::{IndexedParallelIterator, ParallelIterator};
4
5use crate::{
6 em::{
7 likelihood::{LogLikelihood, SumOf},
8 EmSite, StreamEmSite,
9 },
10 io::ReadSite,
11 saf::iter::{IntoParallelSiteIterator, IntoSiteIterator},
12};
13
14use super::{Sfs, USfs};
15
/// Lower bound applied to every SFS value before it is used in an EM computation.
///
/// Each E-step and log-likelihood method passes this to `restrict`, which raises any
/// SFS value smaller than this bound up to it, so that no value is zero.
const RESTRICT_MIN: f64 = f64::EPSILON;
18
19impl<const D: usize> Sfs<D> {
20 /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
21 /// in each frequency bin given the SFS and the input.
22 ///
23 /// This corresponds to an E-step for the EM algorithm. The returned SFS corresponds to the
24 /// expected number of sites in each bin given `self` and the `input`.
25 /// The sum of the returned SFS will be equal to the number of sites in the input.
26 ///
27 /// # Panics
28 ///
29 /// Panics if any of the sites in the input does not fit the shape of `self`.
30 ///
31 /// # Examples
32 ///
33 /// ```
34 /// use winsfs_core::{sfs::Sfs, saf1d, sfs1d};
35 /// let sfs = Sfs::uniform([5]);
36 /// let saf = saf1d![
37 /// [1., 0., 0., 0., 0.],
38 /// [0., 1., 0., 0., 0.],
39 /// [1., 0., 0., 0., 0.],
40 /// [0., 0., 0., 1., 0.],
41 /// ];
42 /// let (log_likelihood, posterior) = sfs.clone().e_step(&saf);
43 /// assert_eq!(posterior, sfs1d![2., 1., 0., 1., 0.]);
44 /// assert_eq!(log_likelihood, sfs.log_likelihood(&saf));
45 /// ```
46 pub fn e_step<I>(mut self, input: I) -> (SumOf<LogLikelihood>, USfs<D>)
47 where
48 I: IntoSiteIterator<D>,
49 I::Item: EmSite<D>,
50 {
51 self = restrict(self, RESTRICT_MIN);
52 let iter = input.into_site_iter();
53 let sites = iter.len();
54
55 let (log_likelihood, posterior, _) = iter.fold(
56 (
57 LogLikelihood::from(0.0),
58 USfs::zeros(self.shape),
59 USfs::zeros(self.shape),
60 ),
61 |(mut log_likelihood, mut posterior, mut buf), site| {
62 log_likelihood += site.posterior_into(&self, &mut posterior, &mut buf).ln();
63
64 (log_likelihood, posterior, buf)
65 },
66 );
67
68 (SumOf::new(log_likelihood, sites), posterior)
69 }
70
71 /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
72 /// in each frequency bin given the SFS and the input.
73 ///
74 /// This is the parallel version of [`Sfs::e_step`], see also its documentation for more.
75 ///
76 /// # Panics
77 ///
78 /// Panics if any of the sites in the input does not fit the shape of `self`.
79 ///
80 /// # Examples
81 ///
82 /// ```
83 /// use winsfs_core::{sfs::Sfs, saf1d, sfs1d};
84 /// let sfs = Sfs::uniform([5]);
85 /// let saf = saf1d![
86 /// [1., 0., 0., 0., 0.],
87 /// [0., 1., 0., 0., 0.],
88 /// [1., 0., 0., 0., 0.],
89 /// [0., 0., 0., 1., 0.],
90 /// ];
91 /// let (log_likelihood, posterior) = sfs.clone().par_e_step(&saf);
92 /// assert_eq!(posterior, sfs1d![2., 1., 0., 1., 0.]);
93 /// assert_eq!(log_likelihood, sfs.log_likelihood(&saf));
94 /// ```
95 pub fn par_e_step<I>(mut self, input: I) -> (SumOf<LogLikelihood>, USfs<D>)
96 where
97 I: IntoParallelSiteIterator<D>,
98 I::Item: EmSite<D>,
99 {
100 self = restrict(self, RESTRICT_MIN);
101 let iter = input.into_par_site_iter();
102 let sites = iter.len();
103
104 let (log_likelihood, posterior) = iter
105 .fold(
106 || {
107 (
108 LogLikelihood::from(0.0),
109 USfs::zeros(self.shape),
110 USfs::zeros(self.shape),
111 )
112 },
113 |(mut log_likelihood, mut posterior, mut buf), site| {
114 log_likelihood += site.posterior_into(&self, &mut posterior, &mut buf).ln();
115
116 (log_likelihood, posterior, buf)
117 },
118 )
119 .map(|(log_likelihood, posterior, _buf)| (log_likelihood, posterior))
120 .reduce(
121 || (LogLikelihood::from(0.0), USfs::zeros(self.shape)),
122 |a, b| (a.0 + b.0, a.1 + b.1),
123 );
124
125 (SumOf::new(log_likelihood, sites), posterior)
126 }
127
128 /// Returns the log-likelihood of the data given the SFS.
129 ///
130 /// # Panics
131 ///
132 /// Panics if any of the sites in the input does not fit the shape of `self`.
133 ///
134 /// # Examples
135 ///
136 /// ```
137 /// use winsfs_core::{em::likelihood::{Likelihood, SumOf}, sfs::Sfs, saf1d, sfs1d};
138 /// let sfs = Sfs::uniform([5]);
139 /// let saf = saf1d![
140 /// [1., 0., 0., 0., 0.],
141 /// [0., 1., 0., 0., 0.],
142 /// [1., 0., 0., 0., 0.],
143 /// [0., 0., 0., 1., 0.],
144 /// ];
145 /// let expected = SumOf::new(Likelihood::from(0.2f64.powi(4)).ln(), saf.sites());
146 /// assert_eq!(sfs.log_likelihood(&saf), expected);
147 /// ```
148 pub fn log_likelihood<I>(mut self, input: I) -> SumOf<LogLikelihood>
149 where
150 I: IntoSiteIterator<D>,
151 {
152 self = restrict(self, RESTRICT_MIN);
153 let iter = input.into_site_iter();
154 let sites = iter.len();
155
156 let log_likelihood = iter.fold(LogLikelihood::from(0.0), |log_likelihood, site| {
157 log_likelihood + site.log_likelihood(&self)
158 });
159
160 SumOf::new(log_likelihood, sites)
161 }
162
163 /// Returns the log-likelihood of the data given the SFS.
164 ///
165 /// This is the parallel version of [`Sfs::log_likelihood`].
166 ///
167 /// # Panics
168 ///
169 /// Panics if any of the sites in the input does not fit the shape of `self`.
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use winsfs_core::{em::likelihood::{Likelihood, SumOf}, sfs::Sfs, saf1d, sfs1d};
175 /// let sfs = Sfs::uniform([5]);
176 /// let saf = saf1d![
177 /// [1., 0., 0., 0., 0.],
178 /// [0., 1., 0., 0., 0.],
179 /// [1., 0., 0., 0., 0.],
180 /// [0., 0., 0., 1., 0.],
181 /// ];
182 /// let expected = SumOf::new(Likelihood::from(0.2f64.powi(4)).ln(), saf.sites());
183 /// assert_eq!(sfs.par_log_likelihood(&saf), expected);
184 /// ```
185 pub fn par_log_likelihood<I>(mut self, input: I) -> SumOf<LogLikelihood>
186 where
187 I: IntoParallelSiteIterator<D>,
188 {
189 self = restrict(self, RESTRICT_MIN);
190 let iter = input.into_par_site_iter();
191 let sites = iter.len();
192
193 let log_likelihood = iter
194 .fold(
195 || LogLikelihood::from(0.0),
196 |log_likelihood, site| log_likelihood + site.log_likelihood(&self),
197 )
198 .sum();
199
200 SumOf::new(log_likelihood, sites)
201 }
202
203 /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
204 /// in each frequency bin given the SFS and the input.
205 ///
206 /// This is the streaming version of [`Sfs::e_step`], see also its documentation for more.
207 ///
208 /// # Panics
209 ///
210 /// Panics if any of the sites in the input does not fit the shape of `self`.
211 pub fn stream_e_step<R>(mut self, mut reader: R) -> io::Result<(SumOf<LogLikelihood>, USfs<D>)>
212 where
213 R: ReadSite,
214 R::Site: StreamEmSite<D>,
215 {
216 self = restrict(self, RESTRICT_MIN);
217 let mut post = USfs::zeros(self.shape);
218 let mut buf = USfs::zeros(self.shape);
219
220 let mut site = <R::Site>::from_shape(self.shape);
221
222 let mut sites = 0;
223 let mut log_likelihood = LogLikelihood::from(0.0);
224 while reader.read_site(&mut site)?.is_not_done() {
225 log_likelihood += site.posterior_into(&self, &mut post, &mut buf).ln();
226
227 sites += 1;
228 }
229
230 Ok((SumOf::new(log_likelihood, sites), post))
231 }
232
233 /// Returns the log-likelihood of the data given the SFS.
234 ///
235 /// This is the streaming version of [`Sfs::log_likelihood`].
236 ///
237 /// # Panics
238 ///
239 /// Panics if any of the sites in the input does not fit the shape of `self`.
240 pub fn stream_log_likelihood<R>(mut self, mut reader: R) -> io::Result<SumOf<LogLikelihood>>
241 where
242 R: ReadSite,
243 R::Site: StreamEmSite<D>,
244 {
245 self = restrict(self, RESTRICT_MIN);
246 let mut site = <R::Site>::from_shape(self.shape);
247
248 let mut sites = 0;
249 let mut log_likelihood = LogLikelihood::from(0.0);
250 while reader.read_site(&mut site)?.is_not_done() {
251 log_likelihood += site.log_likelihood(&self);
252
253 sites += 1;
254 }
255
256 Ok(SumOf::new(log_likelihood, sites))
257 }
258}
259
260/// Restricts the SFS so that all values in the spectrum are above `min`.
261///
262/// We have to ensure that that no value in the SFS is zero. If that happens, a situation can
263/// arise in which a site arrives with information only in the part of the SFS that is zero: this
264/// will lead to a zero posterior that cannot be normalised.
265fn restrict<const D: usize>(mut sfs: Sfs<D>, min: f64) -> Sfs<D> {
266 sfs.values.iter_mut().for_each(|v| {
267 if *v < min {
268 *v = min;
269 }
270 });
271
272 sfs
273}