trellis_runner/state/
mod.rs

1mod status;
2
3use crate::TrellisFloat;
4
5use num_traits::float::FloatCore;
6use web_time::Duration;
7
8pub use status::{Cause, Status};
9
/// Information a user-defined state can hand back to the solver at the end of an iteration
#[derive(Clone, Debug)]
pub enum UpdateData<T> {
    /// The iteration produced an estimate of the error, given in both relative and absolute
    /// form
    ErrorEstimate { relative: T, absolute: T },
    /// The calculation does not track an error estimate and converges through a different
    /// metric. Returning this variant is how the user tells trellis that convergence has been
    /// achieved
    Complete,
}
18
/// Newtype around a single error value
///
/// It exists so that one number can be converted into an [`UpdateData`] error estimate.
#[derive(Clone, Debug)]
pub struct ErrorEstimate<T>(pub T);
22
23impl<T: Clone> From<ErrorEstimate<T>> for Option<UpdateData<T>> {
24    fn from(estimate: ErrorEstimate<T>) -> Self {
25        Some(UpdateData::ErrorEstimate {
26            relative: estimate.0.clone(),
27            absolute: estimate.0,
28        })
29    }
30}
31
32/// The user-defined state must implement this trait to be used as part of the trellis calculation
33/// loop
34///
35/// All other state methods are auto-implemented on a type wrapping the user-defined state.
36pub trait UserState {
37    type Float: TrellisFloat;
38    type Param;
39
40    /// Create a new instance of the user-defined state object
41    fn new() -> Self;
42
43    // Returns true when the state object is initialised correctly
44    fn is_initialised(&self) -> bool {
45        true
46    }
47    // Update the state object at the end of an iteration
48    //
49    // The update method can be used to control convergence:
50    // - By returning an [`UpdateData::ErrorEstimate`] the error estimate will be compared to the
51    //  solver's absolute and relative tolerances. Termination will happen automatically when these
52    //  conditions are satisfied.
53    // - By returning [`UpdateData::Complete`] the solver will terminate immediately
54    // - By returning [`None`] the solver will continue until max iterations
55    fn update(&mut self) -> impl Into<Option<UpdateData<Self::Float>>>;
56    // Returns the current parameter value, if one is assigned
57    fn get_param(&self) -> Option<&Self::Param>;
58    // Returns true if the last iteration was the best iteration seen so far
59    fn last_was_best(&mut self);
60}
61
62/// The state of the [`trellis`] solver
63///
64/// This contains generic fields common to all solvers, as well as a user-defined state
65/// `S` which contains application specific fields.
66pub struct State<S: UserState> {
67    /// The specific component of the state implements the application specific code
68    specific: Option<S>,
69    /// The current iteration number
70    iter: usize,
71    /// The last iteration number where the smallest error estimate was found
72    last_best_iter: usize,
73    /// The maximum number of permitted iterations
74    max_iter: usize,
75    /// The time since the solver was instantiated
76    time: Option<Duration>,
77    /// The termination status of the solver
78    pub(crate) termination_status: Status,
79    /// The current estimate of the error, that observed in the previous iteration
80    ///
81    /// Note that all stored error values are absolute, to prevent issues at low result values
82    error: S::Float,
83    /// The estimate of the error observed in the one before last iteration
84    prev_error: S::Float,
85    /// The best value of the error observed during the entire calculation
86    best_error: S::Float,
87    /// The second best value of the error observed during the entire calculation
88    prev_best_error: S::Float,
89    /// The target relative tolerance
90    relative_tolerance: S::Float,
91    /// The target relative tolerance
92    absolute_tolerance: S::Float,
93}
94
95impl<S> State<S>
96where
97    S: UserState,
98    <S as UserState>::Float: FloatCore,
99{
100    /// Create a new instance of the iteration state
101    pub(crate) fn new() -> Self {
102        Self {
103            specific: Some(S::new()),
104            iter: 0,
105            last_best_iter: 0,
106            max_iter: usize::MAX,
107            termination_status: Status::NotTerminated,
108            time: None,
109            relative_tolerance: <<S as UserState>::Float as FloatCore>::epsilon(),
110            absolute_tolerance: <<S as UserState>::Float as FloatCore>::epsilon(),
111            error: <<S as UserState>::Float as FloatCore>::infinity(),
112            prev_error: <<S as UserState>::Float as FloatCore>::infinity(),
113            best_error: <<S as UserState>::Float as FloatCore>::infinity(),
114            prev_best_error: <<S as UserState>::Float as FloatCore>::infinity(),
115        }
116    }
117
118    /// Record the time since the solver began
119    pub(crate) fn record_time(&mut self, duration: Duration) {
120        self.time = Some(duration);
121    }
122
123    pub(crate) fn duration(&self) -> Option<&Duration> {
124        self.time.as_ref()
125    }
126
127    /// Increment the iteration count
128    pub(crate) fn increment_iteration(&mut self) {
129        self.iter += 1;
130    }
131
132    /// Returns the current iteration number
133    pub(crate) fn current_iteration(&self) -> usize {
134        self.iter
135    }
136
137    /// Returns the number of iterations since the best result was observed
138    pub(crate) fn iterations_since_best(&self) -> usize {
139        self.iter - self.last_best_iter
140    }
141    /// Returns true if the state has been initialised. This means a problem specific inner solver
142    /// has been attached
143    pub(crate) fn is_initialised(&self) -> bool {
144        self.specific
145            .as_ref()
146            .is_some_and(|state| state.is_initialised())
147    }
148
149    /// Returns true if the termination status is [`Status::Terminated`]
150    pub(crate) fn is_terminated(&self) -> bool {
151        self.termination_status != Status::NotTerminated
152    }
153
154    /// Terminates the solver for [`Cause`]
155    pub fn terminate_due_to(mut self, reason: Cause) -> Self {
156        self.termination_status = Status::Terminated(reason);
157        self
158    }
159
160    /// Returns Some if the solver is terminated, else returns None
161    pub(crate) fn termination_cause(&self) -> Option<&Cause> {
162        use Status::*;
163        match &self.termination_status {
164            NotTerminated => None,
165            Terminated(cause) => Some(cause),
166        }
167    }
168
169    #[must_use]
170    /// Update the state, and the interan state
171    pub(crate) fn update(mut self) -> Self {
172        let mut specific = self.specific.take().unwrap();
173        match specific.update().into() {
174            // If an error estimate was provided update the internal state accordingly
175            Some(UpdateData::ErrorEstimate { absolute, .. }) => {
176                self.error = absolute;
177                if self.error < self.best_error
178                    || (FloatCore::is_infinite(self.error)
179                        && FloatCore::is_infinite(self.best_error)
180                        && FloatCore::is_sign_positive(self.error)
181                            == FloatCore::is_sign_positive(self.best_error))
182                {
183                    std::mem::swap(&mut self.prev_best_error, &mut self.best_error);
184                    self.best_error = self.error;
185                    self.last_best_iter = self.iter;
186
187                    specific.last_was_best();
188                }
189            }
190            // If the calculation completed successfully return
191            Some(UpdateData::Complete) => {
192                return self
193                    .set_specific(specific)
194                    .terminate_due_to(Cause::Converged);
195            }
196            _ => (),
197        };
198
199        self = self.set_specific(specific);
200
201        if self.error < self.absolute_tolerance {
202            return self.terminate_due_to(Cause::Converged);
203        }
204        if self.current_iteration() > self.max_iter {
205            return self.terminate_due_to(Cause::ExceededMaxIterations);
206        }
207
208        self
209    }
210
211    /// Returns the parameter vector from the inner state variable
212    pub(crate) fn get_param(&self) -> Option<&S::Param> {
213        self.specific
214            .as_ref()
215            .and_then(|specific| specific.get_param())
216    }
217
218    /// Returns the current measure of progress
219    pub(crate) fn measure(&self) -> S::Float {
220        self.error
221    }
222
223    /// Returns the best measure of progress
224    pub(crate) fn best_measure(&self) -> S::Float {
225        self.best_error
226    }
227
228    /// Removes the specific state from the state and returns it
229    pub fn take_specific(&mut self) -> S {
230        self.specific.take().unwrap()
231    }
232
233    #[must_use]
234    /// Set the relative tolerance target
235    pub fn relative_tolerance(mut self, relative_tolerance: S::Float) -> Self {
236        self.relative_tolerance = relative_tolerance;
237        self
238    }
239
240    #[must_use]
241    /// Set the relative tolerance target
242    pub fn absolute_tolerance(mut self, absolute_tolerance: S::Float) -> Self {
243        self.absolute_tolerance = absolute_tolerance;
244        self
245    }
246
247    #[must_use]
248    /// Set the maximum allowable iteration count
249    pub fn max_iters(mut self, max_iter: usize) -> Self {
250        self.max_iter = max_iter;
251        self
252    }
253
254    #[must_use]
255    /// Set the internal state object
256    pub fn set_specific(mut self, specific: S) -> Self {
257        self.specific = Some(specific);
258        self
259    }
260}