winsfs_core/em/
stopping.rs

1//! Stopping rules used for deciding convergence for EM algorithms.
2
3use crate::sfs::Sfs;
4
5use super::{
6    likelihood::{LogLikelihood, SumOf},
7    EmStep, Inspect,
8};
9
10/// A type that can be combined in various ways to create a stopping rule for EM types.
11///
12/// This primary serves as a supertrait bound for [`Stop`]. [`Stop`] cannot provide these
13/// methods directly without breaking type inference when calling them.
14pub trait StoppingRule {
15    /// Returns a new stopping rule that requires that *both* this *and* another stopping
16    /// rule must indicicate convergence before stopping.
17    fn and<S>(self, other: S) -> Both<Self, S>
18    where
19        Self: Sized,
20    {
21        Both::new(self, other)
22    }
23
24    /// Inspect the stopping rule after each E-step.
25    ///
26    /// This can only be used for inspecting the state of the stopping rule itself. To inspect
27    /// other aspects of the algorithm, see [`EmStep::inspect`].
28    fn inspect<F>(self, f: F) -> Inspect<Self, F>
29    where
30        Self: Sized,
31        F: FnMut(&Self),
32    {
33        Inspect::new(self, f)
34    }
35
36    /// Returns a new stopping rule that requires that *either* this *or* another stopping
37    /// rule must indicicate convergence before stopping.
38    fn or<S>(self, other: S) -> Either<Self, S>
39    where
40        Self: Sized,
41    {
42        Either::new(self, other)
43    }
44}
45
46/// A type capable of deciding whether an EM algorithm should stop.
47pub trait Stop<T>: StoppingRule {
48    /// A status from the E-step that may be used for checking convergence.
49    type Status;
50
51    /// Returns `true` if the algorithm should stop, `false` otherwise.
52    fn stop<const N: usize>(&mut self, em: &T, status: &Self::Status, sfs: &Sfs<N>) -> bool;
53}
54
/// A stopping rule that ends the EM algorithm after a fixed number of EM-steps.
pub struct Steps {
    /// Number of steps counted so far.
    current_step: usize,
    /// Number of steps at which the rule signals a stop.
    max_steps: usize,
}

impl Steps {
    /// Creates a new stopping rule that allows `steps` EM steps.
    ///
    /// The step counter starts at zero.
    pub fn new(steps: usize) -> Self {
        Steps {
            max_steps: steps,
            current_step: 0,
        }
    }

    /// Returns the number of steps counted so far.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Returns the total number of steps allowed before stopping.
    pub fn steps(&self) -> usize {
        self.max_steps
    }
}
80
81impl StoppingRule for Steps {}
82
83impl<T> Stop<T> for Steps
84where
85    T: EmStep,
86{
87    type Status = T::Status;
88
89    fn stop<const N: usize>(&mut self, _em: &T, _status: &Self::Status, _sfs: &Sfs<N>) -> bool {
90        self.current_step += 1;
91        self.current_step >= self.max_steps
92    }
93}
94
95/// A stopping rule that lets the EM algorithm run until the absolute difference in successive,
96/// normalised log-likelihood values falls below some tolerance.
97///
98/// The log-likelihood will be normalised by the number of sites, so that it becomes a per-site
99/// measure. This makes it easier to find a reasonable tolerance for a range of input sizes.
100pub struct LogLikelihoodTolerance {
101    abs_diff: f64,
102    log_likelihood: f64,
103    tolerance: f64,
104}
105
106impl LogLikelihoodTolerance {
107    /// Returns the absolute difference between the two most recent normalised log-likelihood values.
108    pub fn absolute_difference(&self) -> f64 {
109        self.abs_diff
110    }
111
112    /// Returns the current, normalised log-likelihood value.
113    pub fn log_likelihood(&self) -> LogLikelihood {
114        self.log_likelihood.into()
115    }
116
117    /// Creates a new stopping rule that allows EM steps until the absolute difference in successive,
118    /// normalised log-likelihood values falls below `tolerance`.
119    pub fn new(tolerance: f64) -> Self {
120        Self {
121            abs_diff: f64::INFINITY,
122            log_likelihood: f64::NEG_INFINITY,
123            tolerance,
124        }
125    }
126
127    /// Returns the tolerance defining convergence.
128    pub fn tolerance(&self) -> f64 {
129        self.tolerance
130    }
131
132    /// Provides the implementation of `stop`, shared between `LogLikelihoodTolerance`
133    /// and `WindowLogLikelihoodTolerance`.
134    fn stop_inner(&mut self, new_log_likelihood: f64) -> bool {
135        self.abs_diff = (new_log_likelihood - self.log_likelihood).abs();
136        self.log_likelihood = new_log_likelihood;
137
138        self.abs_diff <= self.tolerance
139    }
140}
141
142impl StoppingRule for LogLikelihoodTolerance {}
143
144impl<T> Stop<T> for LogLikelihoodTolerance
145where
146    T: EmStep<Status = SumOf<LogLikelihood>>,
147{
148    type Status = T::Status;
149
150    fn stop<const N: usize>(&mut self, _em: &T, status: &Self::Status, _sfs: &Sfs<N>) -> bool {
151        let new_log_likelihood = status.normalise();
152
153        self.stop_inner(new_log_likelihood)
154    }
155}
156
157/// A stopping rule for window EM that lets the algorithm run until the successive sum of
158/// normalised block log-likelihood values falls below a certain tolerance.
159///
160/// This is analogous to [`LogLikelihoodTolerance`], but instead of considering the full
161/// (normalised) data log-likelihood, we consider the sum of these values over blocks.
162pub struct WindowLogLikelihoodTolerance {
163    inner: LogLikelihoodTolerance,
164}
165
166impl WindowLogLikelihoodTolerance {
167    /// Returns the absolute difference between the two most recent window log-likelihood values.
168    pub fn absolute_difference(&self) -> f64 {
169        self.inner.absolute_difference()
170    }
171
172    /// Returns the current window log-likelihood value.
173    pub fn log_likelihood(&self) -> LogLikelihood {
174        self.inner.log_likelihood()
175    }
176
177    /// Creates a new stopping rule that allows EM steps until the absolute difference in successive,
178    /// window log-likelihood values falls below `tolerance`.
179    pub fn new(tolerance: f64) -> Self {
180        Self {
181            inner: LogLikelihoodTolerance::new(tolerance),
182        }
183    }
184
185    /// Returns the tolerance defining convergence.
186    pub fn tolerance(&self) -> f64 {
187        self.inner.tolerance()
188    }
189}
190
191impl StoppingRule for WindowLogLikelihoodTolerance {}
192
193impl<T> Stop<T> for WindowLogLikelihoodTolerance
194where
195    T: EmStep<Status = Vec<SumOf<LogLikelihood>>>,
196{
197    type Status = T::Status;
198
199    fn stop<const N: usize>(&mut self, _em: &T, status: &Self::Status, _sfs: &Sfs<N>) -> bool {
200        let new_log_likelihood = status
201            .iter()
202            .map(|block_log_likelihood| block_log_likelihood.normalise())
203            .sum();
204
205        self.inner.stop_inner(new_log_likelihood)
206    }
207}
208
/// A stopping rule that runs the EM algorithm until *both* of the contained stopping rules
/// indicate convergence.
///
/// Typically constructed using [`StoppingRule::and`].
pub struct Both<A, B> {
    /// The rule on the left-hand side of the combination.
    left: A,
    /// The rule on the right-hand side of the combination.
    right: B,
}

impl<A, B> Both<A, B> {
    /// Returns a new stopping rule combining `left` and `right`.
    fn new(left: A, right: B) -> Self {
        Both { left, right }
    }

    /// Returns a reference to the "left" stopping rule.
    pub fn left(&self) -> &A {
        let Both { left, .. } = self;
        left
    }

    /// Returns a reference to the "right" stopping rule.
    pub fn right(&self) -> &B {
        let Both { right, .. } = self;
        right
    }
}
234
235impl<A, B> StoppingRule for Both<A, B>
236where
237    A: StoppingRule,
238    B: StoppingRule,
239{
240}
241
242impl<T, A, B> Stop<T> for Both<A, B>
243where
244    T: EmStep,
245    A: Stop<T, Status = T::Status>,
246    B: Stop<T, Status = T::Status>,
247{
248    type Status = T::Status;
249
250    fn stop<const N: usize>(&mut self, em: &T, status: &Self::Status, sfs: &Sfs<N>) -> bool {
251        self.left.stop(em, status, sfs) && self.right.stop(em, status, sfs)
252    }
253}
254
/// A stopping rule that runs the EM algorithm until *either* of the contained stopping rules
/// indicates convergence.
///
/// Typically constructed using [`StoppingRule::or`].
pub struct Either<A, B> {
    /// The rule on the left-hand side of the combination.
    left: A,
    /// The rule on the right-hand side of the combination.
    right: B,
}

impl<A, B> Either<A, B> {
    /// Returns a new stopping rule combining `left` and `right`.
    fn new(left: A, right: B) -> Self {
        Either { left, right }
    }

    /// Returns a reference to the "left" stopping rule.
    pub fn left(&self) -> &A {
        let Either { left, .. } = self;
        left
    }

    /// Returns a reference to the "right" stopping rule.
    pub fn right(&self) -> &B {
        let Either { right, .. } = self;
        right
    }
}
280
281impl<A, B> StoppingRule for Either<A, B>
282where
283    A: StoppingRule,
284    B: StoppingRule,
285{
286}
287
288impl<T, A, B> Stop<T> for Either<A, B>
289where
290    T: EmStep,
291    A: Stop<T, Status = T::Status>,
292    B: Stop<T, Status = T::Status>,
293{
294    type Status = T::Status;
295
296    fn stop<const N: usize>(&mut self, em: &T, status: &Self::Status, sfs: &Sfs<N>) -> bool {
297        self.left.stop(em, status, sfs) || self.right.stop(em, status, sfs)
298    }
299}