winsfs_core/em/
site.rs

1use crate::{
2    saf::{AsSiteView, Site},
3    sfs::{Sfs, USfs},
4};
5
6use super::likelihood::{Likelihood, LogLikelihood};
7
8/// A type of SAF site that can be used as input for EM.
9///
10/// This trait should not typically be used in user code, except as a trait bound where code has to
11/// be written that is generic over different EM input types.
12pub trait EmSite<const D: usize> {
13    /// Returns the likelihood of a single site given the SFS.
14    ///
15    /// # Panics
16    ///
17    /// Panics if the shape of the SFS does not fit the shape of `self`.
18    ///
19    /// # Examples
20    ///
21    /// ```
22    /// use winsfs_core::{em::{likelihood::LogLikelihood, EmSite}, saf::Site, sfs::Sfs};
23    /// let sfs = Sfs::uniform([5]);
24    /// let site = Site::new(vec![1.0, 0.0, 0.0, 0.0, 0.0], [5]).unwrap();
25    /// assert_eq!(site.log_likelihood(&sfs), LogLikelihood::from(0.2f64.ln()));
26    /// ```
27    fn likelihood(&self, sfs: &Sfs<D>) -> Likelihood;
28
29    /// Returns the log-likelihood of a single site given the SFS.
30    ///
31    /// # Panics
32    ///
33    /// Panics if the shape of the SFS does not fit the shape of `self`.
34    ///
35    /// # Examples
36    ///
37    /// ```
38    /// use winsfs_core::{em::{likelihood::LogLikelihood, EmSite}, saf::Site, sfs::Sfs};
39    /// let sfs = Sfs::uniform([5]);
40    /// let site = Site::new(vec![1.0, 0.0, 0.0, 0.0, 0.0], [5]).unwrap();
41    /// assert_eq!(site.log_likelihood(&sfs), LogLikelihood::from(0.2f64.ln()));
42    /// ```
43    fn log_likelihood(&self, sfs: &Sfs<D>) -> LogLikelihood {
44        self.likelihood(sfs).ln()
45    }
46
47    /// Adds the posterior counts for the site into the provided `posterior` buffer, using the
48    /// extra `buf` to avoid extraneous allocations.
49    ///
50    /// The `buf` will be overwritten, and so it's state is unimportant. The shape of the site
51    /// will be matched against the shape of the SFS, and a panic will be thrown if they do not
52    /// match. The shapes of `posterior` and `buf` are unchecked, but must match the shape of self.
53    ///
54    /// The likelihood of the site given the SFS is returned.
55    ///
56    /// # Panics
57    ///
58    /// Panics if the shape of the SFS does not fit the shape of `self`.
59    fn posterior_into(
60        &self,
61        sfs: &Sfs<D>,
62        posterior: &mut USfs<D>,
63        buf: &mut USfs<D>,
64    ) -> Likelihood;
65}
66
67impl<const D: usize, T> EmSite<D> for T
68where
69    T: AsSiteView<D>,
70{
71    fn likelihood(&self, sfs: &Sfs<D>) -> Likelihood {
72        let site = self.as_site_view();
73        assert_eq!(sfs.shape, site.shape());
74
75        let mut sum = 0.;
76
77        likelihood_inner(
78            sfs.as_slice(),
79            sfs.strides.as_slice(),
80            site.split().as_slice(),
81            &mut sum,
82            1.,
83        );
84
85        sum.into()
86    }
87
88    fn posterior_into(
89        &self,
90        sfs: &Sfs<D>,
91        posterior: &mut USfs<D>,
92        buf: &mut USfs<D>,
93    ) -> Likelihood {
94        let site = self.as_site_view();
95        assert_eq!(sfs.shape, site.shape());
96
97        let mut sum = 0.;
98
99        posterior_inner(
100            sfs.as_slice(),
101            sfs.strides.as_slice(),
102            site.split().as_slice(),
103            buf.as_mut_slice(),
104            &mut sum,
105            1.,
106        );
107
108        // Normalising and adding to the posterior in a single iterator has slightly better perf
109        // than normalising and then adding to posterior.
110        buf.iter_mut()
111            .zip(posterior.iter_mut())
112            .for_each(|(buf, posterior)| {
113                *buf /= sum;
114                *posterior += *buf;
115            });
116
117        sum.into()
118    }
119}
120
121/// A type of SAF site that can be used as input for streaming EM.
122///
123/// Like [`EmSite`], this trait should not typically be used in user code, except as a trait bound
124/// where code has to be written that is generic over different EM input types.
125pub trait StreamEmSite<const D: usize>: EmSite<D> {
126    /// Creates a new site from its shape.
127    ///
128    /// The returned site should be suitable for use as a read buffer.
129    fn from_shape(shape: [usize; D]) -> Self;
130}
131
132impl<const D: usize> StreamEmSite<D> for Site<D> {
133    fn from_shape(shape: [usize; D]) -> Self {
134        let vec = vec![0.0; shape.iter().sum()];
135        Site::new(vec, shape).unwrap()
136    }
137}
138
/// Recursively calculates the likelihood for a site of any dimension.
///
/// `site` holds one slice of SAF values per dimension, `sfs` is the flat SFS, and `strides` gives
/// the offset step into `sfs` for each dimension that is peeled off. The product of the SAF values
/// picked so far is carried in `acc`, and every term is added onto `sum`.
///
/// This is a simplified version of `posterior_inner`. Unlike that function, the SFS sub-slice
/// handed to the recursion is not truncated to a single stride: the zip in the base case simply
/// stops at the end of the innermost site.
fn likelihood_inner(sfs: &[f64], strides: &[usize], site: &[&[f32]], sum: &mut f64, acc: f64) {
    // No dimensions at all: nothing to add.
    let (head, tail) = match site.split_first() {
        Some(parts) => parts,
        None => return,
    };

    if tail.is_empty() {
        // Innermost dimension: pair each SFS entry with its SAF value and accumulate.
        for (&weight, &saf) in sfs.iter().zip(head.iter()) {
            *sum += weight * saf as f64 * acc;
        }
        return;
    }

    // Outer dimension: fold each SAF value into the accumulant and descend into the matching
    // part of the SFS.
    let (stride, inner_strides) = strides.split_first().expect("invalid strides");

    for (index, &saf) in head.iter().enumerate() {
        let start = index * stride;

        likelihood_inner(&sfs[start..], inner_strides, tail, sum, saf as f64 * acc);
    }
}
159
/// Recursively calculates the posterior for a site of any dimension.
///
/// The unnormalised posterior is written into `buf`, and the likelihood is added onto `sum` so
/// that it can later be used for normalisation. Callers should typically pass in a `sum` of zero
/// and an `acc` of one.
///
/// It is assumed that `sfs` and `buf` have the same length, which should correspond to the product
/// of the lengths of the sites in `site`.
fn posterior_inner(
    sfs: &[f64],
    strides: &[usize],
    site: &[&[f32]],
    buf: &mut [f64],
    sum: &mut f64,
    acc: f64,
) {
    // No dimensions at all: nothing to write.
    let (head, tail) = match site.split_first() {
        Some(parts) => parts,
        None => return,
    };

    if tail.is_empty() {
        // Base case: a single site remains, so the SFS slice now corresponds to a single slice
        // along its last dimension, e.g. a row in 2D.
        debug_assert_eq!(sfs.len(), head.len());

        for ((out, &weight), &saf) in buf.iter_mut().zip(sfs).zip(head.iter()) {
            let value = weight * saf as f64 * acc;
            *sum += value;
            *out = value;
        }
        return;
    }

    // Recursive case: several sites remain. For each value in the first site, multiply it into
    // the accumulant, "peel" off the corresponding slice of the SFS and of the buffer, and
    // recurse into one dimension less.
    let (stride, inner_strides) = strides.split_first().expect("invalid strides");

    for (index, &saf) in head.iter().enumerate() {
        let start = index * stride;
        let sfs_part = &sfs[start..][..*stride];
        let buf_part = &mut buf[start..][..*stride];

        posterior_inner(sfs_part, inner_strides, tail, buf_part, sum, saf as f64 * acc);
    }
}
213
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{saf::Site, sfs1d, sfs2d};

    /// Asserts that `x` and `y` differ by strictly less than `epsilon`.
    ///
    /// On failure, the message reports both values and the tolerance.
    fn test_f64_equal(x: f64, y: f64, epsilon: f64) {
        assert!(
            (x - y).abs() < epsilon,
            "expected |{} - {}| < {}",
            x,
            y,
            epsilon
        )
    }

    /// Asserts that `xs` and `ys` have the same length and are element-wise within `epsilon`.
    ///
    /// On failure, the message reports the index of the first offending pair.
    fn test_f64_slice_equal(xs: &[f64], ys: &[f64], epsilon: f64) {
        assert_eq!(xs.len(), ys.len(), "slices have different lengths");

        for (i, (&x, &y)) in xs.iter().zip(ys).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "index {}: expected |{} - {}| < {}",
                i,
                x,
                y,
                epsilon
            )
        }
    }

    #[test]
    fn test_1d() {
        let sfs = sfs1d![1., 2., 3.].normalise();

        let site = Site::new(vec![2., 2., 2.], [3]).unwrap();
        let mut posterior = sfs1d![10., 20., 30.];
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        let expected = vec![10. + 1. / 6., 20. + 1. / 3., 30. + 1. / 2.];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), f64::EPSILON);

        // Both entry points must report the same likelihood for the same site.
        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 2., f64::EPSILON);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }

    #[test]
    fn test_2d() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [1.,  2.,  3.,  4.,  5.],
            [6.,  7.,  8.,  9.,  10.],
            [11., 12., 13., 14., 15.],
        ].normalise();

        let site = Site::new(vec![2., 2., 2., 2., 4., 6., 8., 10.], [3, 5]).unwrap();
        let mut posterior = USfs::from_elem(1., sfs.shape);
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        #[rustfmt::skip]
        let expected = vec![
            1.002564, 1.010256, 1.023077, 1.041026, 1.064103,
            1.015385, 1.035897, 1.061538, 1.092308, 1.128205,
            1.028205, 1.061538, 1.100000, 1.143590, 1.192308,
        ];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), 1e-6);

        // Both entry points must report the same likelihood for the same site.
        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 13., f64::EPSILON);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }

    #[test]
    fn test_3d() {
        let sfs = USfs::from_vec_shape((0..60).map(|x| x as f64).collect(), [3, 4, 5])
            .unwrap()
            .normalise();

        let site = Site::new((1..=12).map(|x| x as f32).collect(), [3, 4, 5]).unwrap();
        let mut posterior = USfs::from_elem(1., sfs.shape);
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        let expected = vec![
            1.00000, 1.00015, 1.00032, 1.00053, 1.00078, 1.00081, 1.00109, 1.00141, 1.00178,
            1.00218, 1.00194, 1.00240, 1.00291, 1.00347, 1.00407, 1.00339, 1.00407, 1.00481,
            1.00560, 1.00645, 1.00517, 1.00611, 1.00711, 1.00818, 1.00931, 1.00808, 1.00945,
            1.01091, 1.01244, 1.01406, 1.01164, 1.01353, 1.01551, 1.01760, 1.01978, 1.01584,
            1.01833, 1.02093, 1.02364, 1.02647, 1.01551, 1.01789, 1.02036, 1.02293, 1.02560,
            1.02182, 1.02509, 1.02848, 1.03200, 1.03563, 1.02909, 1.03338, 1.03782, 1.04240,
            1.04712, 1.03733, 1.04276, 1.04836, 1.05413, 1.06007,
        ];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), 1e-5);

        // Both entry points must report the same likelihood for the same site.
        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 139.8418, 1e-4);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }
}
305}