winsfs_core/em/
window_em.rs

1use std::{collections::VecDeque, io, iter::repeat};
2
3use crate::{
4    io::{Enumerate, ReadSite, Rewind, Take},
5    saf::{iter::IntoBlockIterator, AsSafView},
6    sfs::{Sfs, USfs},
7};
8
9use super::{to_f64, Em, EmSite, EmStep, StreamingEm};
10
11/// A runner of the window EM algorithm.
12///
13/// The window EM algorithm updates the SFS estimate in smaller blocks of data, leading to multiple
14/// updates to the estimate per full EM-step. These block estimates are averaged over a sliding
15/// window to smooth the global estimate. The algorithm can be configured to use different EM-like
16/// algorithms (corresponding to the parameter `T`) for each inner block update step.
17#[derive(Clone, Debug, PartialEq)]
18pub struct WindowEm<const D: usize, T> {
19    em: T,
20    window: Window<D>,
21    block_size: usize,
22}
23
24impl<const D: usize, T> WindowEm<D, T> {
25    /// Returns a new instance of the runner.
26    ///
27    /// The `em` is the inner kind of EM to handle the blocks, and the `block_size` is the number
28    /// of sites per blocks. The provided `window` should match the shape of the input and SFS
29    /// that is provided for inference later. Where no good prior guess for the SFS exists,
30    /// using [`Window::from_zeros`] is recommended.
31    pub fn new(em: T, window: Window<D>, block_size: usize) -> Self {
32        Self {
33            em,
34            window,
35            block_size,
36        }
37    }
38}
39
40impl<const D: usize, T> EmStep for WindowEm<D, T>
41where
42    T: EmStep,
43{
44    type Status = Vec<T::Status>;
45}
46
47impl<const D: usize, I, T> Em<D, I> for WindowEm<D, T>
48where
49    for<'a> &'a I: IntoBlockIterator<D>,
50    for<'a> T: Em<D, <&'a I as IntoBlockIterator<D>>::Item>,
51{
52    fn e_step(&mut self, mut sfs: Sfs<D>, input: &I) -> (Self::Status, USfs<D>) {
53        let mut log_likelihoods = Vec::with_capacity(self.block_size);
54
55        let mut sites = 0;
56
57        for block in input.into_block_iter(self.block_size) {
58            sites += block.as_saf_view().sites();
59
60            let (log_likelihood, posterior) = self.em.e_step(sfs, &block);
61
62            self.window.update(posterior);
63
64            sfs = self.window.sum().normalise();
65            log_likelihoods.push(log_likelihood);
66        }
67
68        (log_likelihoods, sfs.scale(to_f64(sites)))
69    }
70}
71
72impl<const D: usize, R, T> StreamingEm<D, R> for WindowEm<D, T>
73where
74    R: Rewind,
75    R::Site: EmSite<D>,
76    for<'a> T: StreamingEm<D, Take<Enumerate<&'a mut R>>>,
77{
78    fn stream_e_step(
79        &mut self,
80        mut sfs: Sfs<D>,
81        reader: &mut R,
82    ) -> io::Result<(Self::Status, USfs<D>)> {
83        let mut log_likelihoods = Vec::with_capacity(self.block_size);
84
85        let mut sites = 0;
86
87        loop {
88            let mut block_reader = reader.take(self.block_size);
89
90            let (log_likelihood, posterior) = self.em.stream_e_step(sfs, &mut block_reader)?;
91            self.window.update(posterior);
92
93            sfs = self.window.sum().normalise();
94            log_likelihoods.push(log_likelihood);
95
96            sites += block_reader.sites_read();
97
98            if reader.is_done()? {
99                break;
100            }
101        }
102
103        Ok((log_likelihoods, sfs.scale(to_f64(sites))))
104    }
105}
106
107/// A window of block SFS estimates, used in window EM.
108///
109/// As part of the window EM algorithm, "windows" of block estimates are averaged out to give
110/// a running estimate of the SFS. The "window size" governs the number of past block estimates
111/// that are remembered and averaged over.
112#[derive(Clone, Debug, PartialEq)]
113pub struct Window<const D: usize> {
114    // Items are ordered old to new: oldest iterations are at the front, newest at the back
115    deque: VecDeque<USfs<D>>,
116}
117
118impl<const D: usize> Window<D> {
119    /// Creates a new window of with size `window_size` by repeating a provided SFS.
120    pub fn from_initial(initial: USfs<D>, window_size: usize) -> Self {
121        let deque = repeat(initial).take(window_size).collect();
122
123        Self { deque }
124    }
125
126    /// Creates a new window of zero-initialised SFS with size `window_size`.
127    pub fn from_zeros(shape: [usize; D], window_size: usize) -> Self {
128        Self::from_initial(USfs::zeros(shape), window_size)
129    }
130
131    /// Returns the shape of the window.
132    pub fn shape(&self) -> [usize; D] {
133        // We maintain as invariant that all items in deque have same shape,
134        // in order to make this okay
135        *(self.deque[0].shape())
136    }
137
138    /// Returns the sum of SFS in the window.
139    fn sum(&self) -> USfs<D> {
140        let first = USfs::zeros(self.shape());
141
142        self.deque.iter().fold(first, |sum, item| sum + item)
143    }
144
145    /// Updates the window after a new iteration of window EM.
146    ///
147    /// This corresponds to removing the oldest SFS from the window, and adding the new `sfs`.
148    fn update(&mut self, sfs: USfs<D>) {
149        if *sfs.shape() != self.shape() {
150            panic!("shape of provided SFS does not match shape of window")
151        }
152
153        let _old = self.deque.pop_front();
154        self.deque.push_back(sfs);
155    }
156
157    /// Returns the window size, corresponding to the number of past block estimates in the window.
158    pub fn window_size(&self) -> usize {
159        self.deque.len()
160    }
161}