winsfs_core/em/
window_em.rs1use std::{collections::VecDeque, io, iter::repeat};
2
3use crate::{
4 io::{Enumerate, ReadSite, Rewind, Take},
5 saf::{iter::IntoBlockIterator, AsSafView},
6 sfs::{Sfs, USfs},
7};
8
9use super::{to_f64, Em, EmSite, EmStep, StreamingEm};
10
11#[derive(Clone, Debug, PartialEq)]
18pub struct WindowEm<const D: usize, T> {
19 em: T,
20 window: Window<D>,
21 block_size: usize,
22}
23
24impl<const D: usize, T> WindowEm<D, T> {
25 pub fn new(em: T, window: Window<D>, block_size: usize) -> Self {
32 Self {
33 em,
34 window,
35 block_size,
36 }
37 }
38}
39
40impl<const D: usize, T> EmStep for WindowEm<D, T>
41where
42 T: EmStep,
43{
44 type Status = Vec<T::Status>;
45}
46
47impl<const D: usize, I, T> Em<D, I> for WindowEm<D, T>
48where
49 for<'a> &'a I: IntoBlockIterator<D>,
50 for<'a> T: Em<D, <&'a I as IntoBlockIterator<D>>::Item>,
51{
52 fn e_step(&mut self, mut sfs: Sfs<D>, input: &I) -> (Self::Status, USfs<D>) {
53 let mut log_likelihoods = Vec::with_capacity(self.block_size);
54
55 let mut sites = 0;
56
57 for block in input.into_block_iter(self.block_size) {
58 sites += block.as_saf_view().sites();
59
60 let (log_likelihood, posterior) = self.em.e_step(sfs, &block);
61
62 self.window.update(posterior);
63
64 sfs = self.window.sum().normalise();
65 log_likelihoods.push(log_likelihood);
66 }
67
68 (log_likelihoods, sfs.scale(to_f64(sites)))
69 }
70}
71
72impl<const D: usize, R, T> StreamingEm<D, R> for WindowEm<D, T>
73where
74 R: Rewind,
75 R::Site: EmSite<D>,
76 for<'a> T: StreamingEm<D, Take<Enumerate<&'a mut R>>>,
77{
78 fn stream_e_step(
79 &mut self,
80 mut sfs: Sfs<D>,
81 reader: &mut R,
82 ) -> io::Result<(Self::Status, USfs<D>)> {
83 let mut log_likelihoods = Vec::with_capacity(self.block_size);
84
85 let mut sites = 0;
86
87 loop {
88 let mut block_reader = reader.take(self.block_size);
89
90 let (log_likelihood, posterior) = self.em.stream_e_step(sfs, &mut block_reader)?;
91 self.window.update(posterior);
92
93 sfs = self.window.sum().normalise();
94 log_likelihoods.push(log_likelihood);
95
96 sites += block_reader.sites_read();
97
98 if reader.is_done()? {
99 break;
100 }
101 }
102
103 Ok((log_likelihoods, sfs.scale(to_f64(sites))))
104 }
105}
106
107#[derive(Clone, Debug, PartialEq)]
113pub struct Window<const D: usize> {
114 deque: VecDeque<USfs<D>>,
116}
117
118impl<const D: usize> Window<D> {
119 pub fn from_initial(initial: USfs<D>, window_size: usize) -> Self {
121 let deque = repeat(initial).take(window_size).collect();
122
123 Self { deque }
124 }
125
126 pub fn from_zeros(shape: [usize; D], window_size: usize) -> Self {
128 Self::from_initial(USfs::zeros(shape), window_size)
129 }
130
131 pub fn shape(&self) -> [usize; D] {
133 *(self.deque[0].shape())
136 }
137
138 fn sum(&self) -> USfs<D> {
140 let first = USfs::zeros(self.shape());
141
142 self.deque.iter().fold(first, |sum, item| sum + item)
143 }
144
145 fn update(&mut self, sfs: USfs<D>) {
149 if *sfs.shape() != self.shape() {
150 panic!("shape of provided SFS does not match shape of window")
151 }
152
153 let _old = self.deque.pop_front();
154 self.deque.push_back(sfs);
155 }
156
157 pub fn window_size(&self) -> usize {
159 self.deque.len()
160 }
161}