winsfs_core/sfs/
em.rs

1use std::io;
2
3use rayon::iter::{IndexedParallelIterator, ParallelIterator};
4
5use crate::{
6    em::{
7        likelihood::{LogLikelihood, SumOf},
8        EmSite, StreamEmSite,
9    },
10    io::ReadSite,
11    saf::iter::{IntoParallelSiteIterator, IntoSiteIterator},
12};
13
14use super::{Sfs, USfs};
15
/// The minimum allowable SFS value during EM.
///
/// Spectra are clamped to this floor (see `restrict`) before each E-step so that no
/// frequency bin is exactly zero, which could otherwise produce a site posterior that
/// cannot be normalised.
const RESTRICT_MIN: f64 = f64::EPSILON;
18
19impl<const D: usize> Sfs<D> {
20    /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
21    /// in each frequency bin given the SFS and the input.
22    ///
23    /// This corresponds to an E-step for the EM algorithm. The returned SFS corresponds to the
24    /// expected number of sites in each bin given `self` and the `input`.
25    /// The sum of the returned SFS will be equal to the number of sites in the input.
26    ///
27    /// # Panics
28    ///
29    /// Panics if any of the sites in the input does not fit the shape of `self`.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use winsfs_core::{sfs::Sfs, saf1d, sfs1d};
35    /// let sfs = Sfs::uniform([5]);
36    /// let saf = saf1d![
37    ///     [1., 0., 0., 0., 0.],
38    ///     [0., 1., 0., 0., 0.],
39    ///     [1., 0., 0., 0., 0.],
40    ///     [0., 0., 0., 1., 0.],
41    /// ];
42    /// let (log_likelihood, posterior) = sfs.clone().e_step(&saf);
43    /// assert_eq!(posterior, sfs1d![2., 1., 0., 1., 0.]);
44    /// assert_eq!(log_likelihood, sfs.log_likelihood(&saf));
45    /// ```
46    pub fn e_step<I>(mut self, input: I) -> (SumOf<LogLikelihood>, USfs<D>)
47    where
48        I: IntoSiteIterator<D>,
49        I::Item: EmSite<D>,
50    {
51        self = restrict(self, RESTRICT_MIN);
52        let iter = input.into_site_iter();
53        let sites = iter.len();
54
55        let (log_likelihood, posterior, _) = iter.fold(
56            (
57                LogLikelihood::from(0.0),
58                USfs::zeros(self.shape),
59                USfs::zeros(self.shape),
60            ),
61            |(mut log_likelihood, mut posterior, mut buf), site| {
62                log_likelihood += site.posterior_into(&self, &mut posterior, &mut buf).ln();
63
64                (log_likelihood, posterior, buf)
65            },
66        );
67
68        (SumOf::new(log_likelihood, sites), posterior)
69    }
70
71    /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
72    /// in each frequency bin given the SFS and the input.
73    ///
74    /// This is the parallel version of [`Sfs::e_step`], see also its documentation for more.
75    ///
76    /// # Panics
77    ///
78    /// Panics if any of the sites in the input does not fit the shape of `self`.
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use winsfs_core::{sfs::Sfs, saf1d, sfs1d};
84    /// let sfs = Sfs::uniform([5]);
85    /// let saf = saf1d![
86    ///     [1., 0., 0., 0., 0.],
87    ///     [0., 1., 0., 0., 0.],
88    ///     [1., 0., 0., 0., 0.],
89    ///     [0., 0., 0., 1., 0.],
90    /// ];
91    /// let (log_likelihood, posterior) = sfs.clone().par_e_step(&saf);
92    /// assert_eq!(posterior, sfs1d![2., 1., 0., 1., 0.]);
93    /// assert_eq!(log_likelihood, sfs.log_likelihood(&saf));
94    /// ```
95    pub fn par_e_step<I>(mut self, input: I) -> (SumOf<LogLikelihood>, USfs<D>)
96    where
97        I: IntoParallelSiteIterator<D>,
98        I::Item: EmSite<D>,
99    {
100        self = restrict(self, RESTRICT_MIN);
101        let iter = input.into_par_site_iter();
102        let sites = iter.len();
103
104        let (log_likelihood, posterior) = iter
105            .fold(
106                || {
107                    (
108                        LogLikelihood::from(0.0),
109                        USfs::zeros(self.shape),
110                        USfs::zeros(self.shape),
111                    )
112                },
113                |(mut log_likelihood, mut posterior, mut buf), site| {
114                    log_likelihood += site.posterior_into(&self, &mut posterior, &mut buf).ln();
115
116                    (log_likelihood, posterior, buf)
117                },
118            )
119            .map(|(log_likelihood, posterior, _buf)| (log_likelihood, posterior))
120            .reduce(
121                || (LogLikelihood::from(0.0), USfs::zeros(self.shape)),
122                |a, b| (a.0 + b.0, a.1 + b.1),
123            );
124
125        (SumOf::new(log_likelihood, sites), posterior)
126    }
127
128    /// Returns the log-likelihood of the data given the SFS.
129    ///
130    /// # Panics
131    ///
132    /// Panics if any of the sites in the input does not fit the shape of `self`.
133    ///
134    /// # Examples
135    ///
136    /// ```
137    /// use winsfs_core::{em::likelihood::{Likelihood, SumOf}, sfs::Sfs, saf1d, sfs1d};
138    /// let sfs = Sfs::uniform([5]);
139    /// let saf = saf1d![
140    ///     [1., 0., 0., 0., 0.],
141    ///     [0., 1., 0., 0., 0.],
142    ///     [1., 0., 0., 0., 0.],
143    ///     [0., 0., 0., 1., 0.],
144    /// ];
145    /// let expected = SumOf::new(Likelihood::from(0.2f64.powi(4)).ln(), saf.sites());
146    /// assert_eq!(sfs.log_likelihood(&saf), expected);
147    /// ```
148    pub fn log_likelihood<I>(mut self, input: I) -> SumOf<LogLikelihood>
149    where
150        I: IntoSiteIterator<D>,
151    {
152        self = restrict(self, RESTRICT_MIN);
153        let iter = input.into_site_iter();
154        let sites = iter.len();
155
156        let log_likelihood = iter.fold(LogLikelihood::from(0.0), |log_likelihood, site| {
157            log_likelihood + site.log_likelihood(&self)
158        });
159
160        SumOf::new(log_likelihood, sites)
161    }
162
163    /// Returns the log-likelihood of the data given the SFS.
164    ///
165    /// This is the parallel version of [`Sfs::log_likelihood`].
166    ///
167    /// # Panics
168    ///
169    /// Panics if any of the sites in the input does not fit the shape of `self`.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use winsfs_core::{em::likelihood::{Likelihood, SumOf}, sfs::Sfs, saf1d, sfs1d};
175    /// let sfs = Sfs::uniform([5]);
176    /// let saf = saf1d![
177    ///     [1., 0., 0., 0., 0.],
178    ///     [0., 1., 0., 0., 0.],
179    ///     [1., 0., 0., 0., 0.],
180    ///     [0., 0., 0., 1., 0.],
181    /// ];
182    /// let expected = SumOf::new(Likelihood::from(0.2f64.powi(4)).ln(), saf.sites());
183    /// assert_eq!(sfs.par_log_likelihood(&saf), expected);
184    /// ```
185    pub fn par_log_likelihood<I>(mut self, input: I) -> SumOf<LogLikelihood>
186    where
187        I: IntoParallelSiteIterator<D>,
188    {
189        self = restrict(self, RESTRICT_MIN);
190        let iter = input.into_par_site_iter();
191        let sites = iter.len();
192
193        let log_likelihood = iter
194            .fold(
195                || LogLikelihood::from(0.0),
196                |log_likelihood, site| log_likelihood + site.log_likelihood(&self),
197            )
198            .sum();
199
200        SumOf::new(log_likelihood, sites)
201    }
202
203    /// Returns the log-likelihood of the data given the SFS, and the expected number of sites
204    /// in each frequency bin given the SFS and the input.
205    ///
206    /// This is the streaming version of [`Sfs::e_step`], see also its documentation for more.
207    ///
208    /// # Panics
209    ///
210    /// Panics if any of the sites in the input does not fit the shape of `self`.
211    pub fn stream_e_step<R>(mut self, mut reader: R) -> io::Result<(SumOf<LogLikelihood>, USfs<D>)>
212    where
213        R: ReadSite,
214        R::Site: StreamEmSite<D>,
215    {
216        self = restrict(self, RESTRICT_MIN);
217        let mut post = USfs::zeros(self.shape);
218        let mut buf = USfs::zeros(self.shape);
219
220        let mut site = <R::Site>::from_shape(self.shape);
221
222        let mut sites = 0;
223        let mut log_likelihood = LogLikelihood::from(0.0);
224        while reader.read_site(&mut site)?.is_not_done() {
225            log_likelihood += site.posterior_into(&self, &mut post, &mut buf).ln();
226
227            sites += 1;
228        }
229
230        Ok((SumOf::new(log_likelihood, sites), post))
231    }
232
233    /// Returns the log-likelihood of the data given the SFS.
234    ///
235    /// This is the streaming version of [`Sfs::log_likelihood`].
236    ///
237    /// # Panics
238    ///
239    /// Panics if any of the sites in the input does not fit the shape of `self`.
240    pub fn stream_log_likelihood<R>(mut self, mut reader: R) -> io::Result<SumOf<LogLikelihood>>
241    where
242        R: ReadSite,
243        R::Site: StreamEmSite<D>,
244    {
245        self = restrict(self, RESTRICT_MIN);
246        let mut site = <R::Site>::from_shape(self.shape);
247
248        let mut sites = 0;
249        let mut log_likelihood = LogLikelihood::from(0.0);
250        while reader.read_site(&mut site)?.is_not_done() {
251            log_likelihood += site.log_likelihood(&self);
252
253            sites += 1;
254        }
255
256        Ok(SumOf::new(log_likelihood, sites))
257    }
258}
259
260/// Restricts the SFS so that all values in the spectrum are above `min`.
261///
262/// We have to ensure that that no value in the SFS is zero. If that happens, a situation can
263/// arise in which a site arrives with information only in the part of the SFS that is zero: this
264/// will lead to a zero posterior that cannot be normalised.
265fn restrict<const D: usize>(mut sfs: Sfs<D>, min: f64) -> Sfs<D> {
266    sfs.values.iter_mut().for_each(|v| {
267        if *v < min {
268            *v = min;
269        }
270    });
271
272    sfs
273}