1use crate::{
2 saf::{AsSiteView, Site},
3 sfs::{Sfs, USfs},
4};
5
6use super::likelihood::{Likelihood, LogLikelihood};
7
8pub trait EmSite<const D: usize> {
13 fn likelihood(&self, sfs: &Sfs<D>) -> Likelihood;
28
29 fn log_likelihood(&self, sfs: &Sfs<D>) -> LogLikelihood {
44 self.likelihood(sfs).ln()
45 }
46
47 fn posterior_into(
60 &self,
61 sfs: &Sfs<D>,
62 posterior: &mut USfs<D>,
63 buf: &mut USfs<D>,
64 ) -> Likelihood;
65}
66
67impl<const D: usize, T> EmSite<D> for T
68where
69 T: AsSiteView<D>,
70{
71 fn likelihood(&self, sfs: &Sfs<D>) -> Likelihood {
72 let site = self.as_site_view();
73 assert_eq!(sfs.shape, site.shape());
74
75 let mut sum = 0.;
76
77 likelihood_inner(
78 sfs.as_slice(),
79 sfs.strides.as_slice(),
80 site.split().as_slice(),
81 &mut sum,
82 1.,
83 );
84
85 sum.into()
86 }
87
88 fn posterior_into(
89 &self,
90 sfs: &Sfs<D>,
91 posterior: &mut USfs<D>,
92 buf: &mut USfs<D>,
93 ) -> Likelihood {
94 let site = self.as_site_view();
95 assert_eq!(sfs.shape, site.shape());
96
97 let mut sum = 0.;
98
99 posterior_inner(
100 sfs.as_slice(),
101 sfs.strides.as_slice(),
102 site.split().as_slice(),
103 buf.as_mut_slice(),
104 &mut sum,
105 1.,
106 );
107
108 buf.iter_mut()
111 .zip(posterior.iter_mut())
112 .for_each(|(buf, posterior)| {
113 *buf /= sum;
114 *posterior += *buf;
115 });
116
117 sum.into()
118 }
119}
120
121pub trait StreamEmSite<const D: usize>: EmSite<D> {
126 fn from_shape(shape: [usize; D]) -> Self;
130}
131
132impl<const D: usize> StreamEmSite<D> for Site<D> {
133 fn from_shape(shape: [usize; D]) -> Self {
134 let vec = vec![0.0; shape.iter().sum()];
135 Site::new(vec, shape).unwrap()
136 }
137}
138
/// Recursively accumulates the site likelihood into `sum`.
///
/// - `site` holds one slice of site values per remaining dimension.
/// - `strides` holds the SFS strides for those dimensions.
/// - `acc` is the product of the site values picked on the dimensions already consumed.
///
/// On the last dimension the SFS is walked contiguously, i.e. with stride one. Each
/// cell adds `sfs value * site value * acc` to `sum`.
///
/// # Panics
///
/// Panics with "invalid strides" if more than one dimension remains but `strides` is empty.
fn likelihood_inner(sfs: &[f64], strides: &[usize], site: &[&[f32]], sum: &mut f64, acc: f64) {
    let (first, rest) = match site.split_first() {
        Some((&first, rest)) => (first, rest),
        // No dimensions left: nothing to add.
        None => return,
    };

    if rest.is_empty() {
        // Innermost dimension: `zip` stops at the end of `first`.
        for (&value, &saf) in sfs.iter().zip(first) {
            *sum += value * saf as f64 * acc;
        }
        return;
    }

    let (&stride, inner_strides) = strides.split_first().expect("invalid strides");

    for (index, &saf) in first.iter().enumerate() {
        let start = index * stride;
        likelihood_inner(&sfs[start..], inner_strides, rest, sum, saf as f64 * acc);
    }
}
159
/// Recursively writes the unnormalised per-cell posterior into `buf`, adding each cell's
/// value to `sum`.
///
/// The arguments mirror `likelihood_inner`, plus `buf`, which is laid out like `sfs`.
/// For each outer index, `sfs` and `buf` are narrowed to a window of `stride` cells
/// before recursing. On the last dimension each cell gets `sfs value * site value * acc`.
///
/// # Panics
///
/// Panics with "invalid strides" if more than one dimension remains but `strides` is empty.
/// It also panics if a window runs past the end of `sfs` or `buf`. In debug builds it
/// checks that the innermost SFS window matches the length of the last site slice.
fn posterior_inner(
    sfs: &[f64],
    strides: &[usize],
    site: &[&[f32]],
    buf: &mut [f64],
    sum: &mut f64,
    acc: f64,
) {
    let (first, rest) = match site.split_first() {
        Some((&first, rest)) => (first, rest),
        None => return,
    };

    if rest.is_empty() {
        debug_assert_eq!(sfs.len(), first.len());

        for ((out, &value), &saf) in buf.iter_mut().zip(sfs).zip(first) {
            let weighted = value * saf as f64 * acc;
            *sum += weighted;
            *out = weighted;
        }
        return;
    }

    let (&stride, inner_strides) = strides.split_first().expect("invalid strides");

    for (index, &saf) in first.iter().enumerate() {
        let start = index * stride;
        posterior_inner(
            &sfs[start..][..stride],
            inner_strides,
            rest,
            &mut buf[start..][..stride],
            sum,
            saf as f64 * acc,
        );
    }
}
213
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{saf::Site, sfs1d, sfs2d};

    /// Asserts that `x` and `y` differ by strictly less than `epsilon`.
    ///
    /// `#[track_caller]` makes a failure report the calling test's line rather than this one.
    #[track_caller]
    fn test_f64_equal(x: f64, y: f64, epsilon: f64) {
        assert!(
            (x - y).abs() < epsilon,
            "{} and {} differ by at least {}",
            x,
            y,
            epsilon
        )
    }

    /// Asserts that `xs` and `ys` have the same length and are element-wise equal to
    /// within `epsilon`, reporting the index of the first mismatch.
    #[track_caller]
    fn test_f64_slice_equal(xs: &[f64], ys: &[f64], epsilon: f64) {
        assert_eq!(xs.len(), ys.len());

        for (i, (&x, &y)) in xs.iter().zip(ys).enumerate() {
            assert!(
                (x - y).abs() < epsilon,
                "element {}: {} and {} differ by at least {}",
                i,
                x,
                y,
                epsilon
            );
        }
    }

    #[test]
    fn test_1d() {
        let sfs = sfs1d![1., 2., 3.].normalise();

        let site = Site::new(vec![2., 2., 2.], [3]).unwrap();
        let mut posterior = sfs1d![10., 20., 30.];
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        let expected = vec![10. + 1. / 6., 20. + 1. / 3., 30. + 1. / 2.];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), f64::EPSILON);

        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 2., f64::EPSILON);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }

    #[test]
    fn test_2d() {
        #[rustfmt::skip]
        let sfs = sfs2d![
            [1., 2., 3., 4., 5.],
            [6., 7., 8., 9., 10.],
            [11., 12., 13., 14., 15.],
        ].normalise();

        let site = Site::new(vec![2., 2., 2., 2., 4., 6., 8., 10.], [3, 5]).unwrap();
        let mut posterior = USfs::from_elem(1., sfs.shape);
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        #[rustfmt::skip]
        let expected = vec![
            1.002564, 1.010256, 1.023077, 1.041026, 1.064103,
            1.015385, 1.035897, 1.061538, 1.092308, 1.128205,
            1.028205, 1.061538, 1.100000, 1.143590, 1.192308,
        ];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), 1e-6);

        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 13., f64::EPSILON);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }

    #[test]
    fn test_3d() {
        let sfs = USfs::from_vec_shape((0..60).map(|x| x as f64).collect(), [3, 4, 5])
            .unwrap()
            .normalise();

        let site = Site::new((1..=12).map(|x| x as f32).collect(), [3, 4, 5]).unwrap();
        let mut posterior = USfs::from_elem(1., sfs.shape);
        let mut buf = USfs::zeros(sfs.shape);

        let posterior_likelihood = site.posterior_into(&sfs, &mut posterior, &mut buf);

        let expected = vec![
            1.00000, 1.00015, 1.00032, 1.00053, 1.00078, 1.00081, 1.00109, 1.00141, 1.00178,
            1.00218, 1.00194, 1.00240, 1.00291, 1.00347, 1.00407, 1.00339, 1.00407, 1.00481,
            1.00560, 1.00645, 1.00517, 1.00611, 1.00711, 1.00818, 1.00931, 1.00808, 1.00945,
            1.01091, 1.01244, 1.01406, 1.01164, 1.01353, 1.01551, 1.01760, 1.01978, 1.01584,
            1.01833, 1.02093, 1.02364, 1.02647, 1.01551, 1.01789, 1.02036, 1.02293, 1.02560,
            1.02182, 1.02509, 1.02848, 1.03200, 1.03563, 1.02909, 1.03338, 1.03782, 1.04240,
            1.04712, 1.03733, 1.04276, 1.04836, 1.05413, 1.06007,
        ];
        test_f64_slice_equal(posterior.as_slice(), expected.as_slice(), 1e-5);

        let likelihood = site.likelihood(&sfs);
        test_f64_equal(likelihood.into(), 139.8418, 1e-4);
        test_f64_equal(likelihood.into(), posterior_likelihood.into(), f64::EPSILON);
    }
}