winsfs_core/em/
adaptors.rs

1use crate::{
2    io::Rewind,
3    sfs::{Sfs, USfs},
4};
5
6use super::{
7    stopping::{Stop, StoppingRule},
8    Em, EmSite, EmStep, StreamingEm,
9};
10
/// A combinator for types that allows inspection after each E-step.
///
/// To inspect the EM process itself, this can be constructed using [`EmStep::inspect`],
/// which is available on anything implementing [`Em`] or [`StreamingEm`]. To inspect a
/// stopping rule, this can be constructed using [`StoppingRule::inspect`].
///
/// `Inspect` is [`Clone`] whenever both the wrapped value and the inspection closure are, and
/// [`Debug`](std::fmt::Debug) whenever the wrapped value is. The closure itself is not shown in
/// the debug output, since closures do not implement `Debug`.
#[derive(Clone)]
pub struct Inspect<T, F> {
    /// The wrapped EM algorithm or stopping rule.
    inner: T,
    /// The inspection callback.
    f: F,
}

// Manual impl (mirroring `std::iter::Inspect`) instead of `#[derive(Debug)]`: the derive would
// add an `F: Debug` bound, which closures never satisfy, so the derived impl was unusable for
// the closure-based adaptors this type is built from.
impl<T, F> std::fmt::Debug for Inspect<T, F>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Inspect")
            .field("inner", &self.inner)
            .finish_non_exhaustive()
    }
}
21
22impl<T, F> Inspect<T, F> {
23    pub(super) fn new(inner: T, f: F) -> Self {
24        Self { inner, f }
25    }
26}
27
28impl<T, F> EmStep for Inspect<T, F>
29where
30    T: EmStep,
31{
32    type Status = T::Status;
33}
34
35impl<const D: usize, T, F, I> Em<D, I> for Inspect<T, F>
36where
37    T: Em<D, I>,
38    F: FnMut(&T, &T::Status, &USfs<D>),
39{
40    fn e_step(&mut self, sfs: Sfs<D>, input: &I) -> (Self::Status, USfs<D>) {
41        let (status, sfs) = self.inner.e_step(sfs, input);
42
43        (self.f)(&self.inner, &status, &sfs);
44
45        (status, sfs)
46    }
47}
48
49impl<const D: usize, T, F, R> StreamingEm<D, R> for Inspect<T, F>
50where
51    R: Rewind,
52    R::Site: EmSite<D>,
53    T: StreamingEm<D, R>,
54    F: FnMut(&T, &T::Status, &USfs<D>),
55{
56    fn stream_e_step(
57        &mut self,
58        sfs: Sfs<D>,
59        reader: &mut R,
60    ) -> std::io::Result<(Self::Status, USfs<D>)> {
61        let (status, sfs) = self.inner.stream_e_step(sfs, reader)?;
62
63        (self.f)(&self.inner, &status, &sfs);
64
65        Ok((status, sfs))
66    }
67}
68
69impl<S, F> StoppingRule for Inspect<S, F> where S: StoppingRule {}
70
71impl<T, S, F> Stop<T> for Inspect<S, F>
72where
73    T: EmStep,
74    S: Stop<T, Status = T::Status>,
75    F: FnMut(&S),
76{
77    type Status = T::Status;
78
79    fn stop<const D: usize>(&mut self, em: &T, status: &Self::Status, sfs: &Sfs<D>) -> bool {
80        (self.f)(&self.inner);
81
82        self.inner.stop(em, status, sfs)
83    }
84}