trellis_runner/state/
mod.rs1mod status;
2
3use crate::TrellisFloat;
4
5use num_traits::float::FloatCore;
6use web_time::Duration;
7
8pub use status::{Cause, Status};
9
/// Report produced (after `Into` conversion) by [`UserState::update`] for one step.
#[derive(Debug, Clone)]
pub enum UpdateData<T> {
    /// Error estimates for the latest step.
    ErrorEstimate {
        /// Relative error estimate.
        relative: T,
        /// Absolute error estimate.
        absolute: T,
    },
    /// The user state reports that it is finished; the runner then terminates as converged.
    Complete,
}
18
/// Convenience wrapper for reporting a single error value.
///
/// Converting it into `Option<UpdateData<T>>` uses the wrapped value as both the relative
/// and the absolute estimate.
#[derive(Debug, Clone)]
pub struct ErrorEstimate<T>(pub T);
22
23impl<T: Clone> From<ErrorEstimate<T>> for Option<UpdateData<T>> {
24 fn from(estimate: ErrorEstimate<T>) -> Self {
25 Some(UpdateData::ErrorEstimate {
26 relative: estimate.0.clone(),
27 absolute: estimate.0,
28 })
29 }
30}
31
32pub trait UserState {
37 type Float: TrellisFloat;
38 type Param;
39
40 fn new() -> Self;
42
43 fn is_initialised(&self) -> bool {
45 true
46 }
47 fn update(&mut self) -> impl Into<Option<UpdateData<Self::Float>>>;
56 fn get_param(&self) -> Option<&Self::Param>;
58 fn last_was_best(&mut self);
60}
61
62pub struct State<S: UserState> {
67 specific: Option<S>,
69 iter: usize,
71 last_best_iter: usize,
73 max_iter: usize,
75 time: Option<Duration>,
77 pub(crate) termination_status: Status,
79 error: S::Float,
83 prev_error: S::Float,
85 best_error: S::Float,
87 prev_best_error: S::Float,
89 relative_tolerance: S::Float,
91 absolute_tolerance: S::Float,
93}
94
95impl<S> State<S>
96where
97 S: UserState,
98 <S as UserState>::Float: FloatCore,
99{
100 pub(crate) fn new() -> Self {
102 Self {
103 specific: Some(S::new()),
104 iter: 0,
105 last_best_iter: 0,
106 max_iter: usize::MAX,
107 termination_status: Status::NotTerminated,
108 time: None,
109 relative_tolerance: <<S as UserState>::Float as FloatCore>::epsilon(),
110 absolute_tolerance: <<S as UserState>::Float as FloatCore>::epsilon(),
111 error: <<S as UserState>::Float as FloatCore>::infinity(),
112 prev_error: <<S as UserState>::Float as FloatCore>::infinity(),
113 best_error: <<S as UserState>::Float as FloatCore>::infinity(),
114 prev_best_error: <<S as UserState>::Float as FloatCore>::infinity(),
115 }
116 }
117
118 pub(crate) fn record_time(&mut self, duration: Duration) {
120 self.time = Some(duration);
121 }
122
123 pub(crate) fn duration(&self) -> Option<&Duration> {
124 self.time.as_ref()
125 }
126
127 pub(crate) fn increment_iteration(&mut self) {
129 self.iter += 1;
130 }
131
132 pub(crate) fn current_iteration(&self) -> usize {
134 self.iter
135 }
136
137 pub(crate) fn iterations_since_best(&self) -> usize {
139 self.iter - self.last_best_iter
140 }
141 pub(crate) fn is_initialised(&self) -> bool {
144 self.specific
145 .as_ref()
146 .is_some_and(|state| state.is_initialised())
147 }
148
149 pub(crate) fn is_terminated(&self) -> bool {
151 self.termination_status != Status::NotTerminated
152 }
153
154 pub fn terminate_due_to(mut self, reason: Cause) -> Self {
156 self.termination_status = Status::Terminated(reason);
157 self
158 }
159
160 pub(crate) fn termination_cause(&self) -> Option<&Cause> {
162 use Status::*;
163 match &self.termination_status {
164 NotTerminated => None,
165 Terminated(cause) => Some(cause),
166 }
167 }
168
169 #[must_use]
170 pub(crate) fn update(mut self) -> Self {
172 let mut specific = self.specific.take().unwrap();
173 match specific.update().into() {
174 Some(UpdateData::ErrorEstimate { absolute, .. }) => {
176 self.error = absolute;
177 if self.error < self.best_error
178 || (FloatCore::is_infinite(self.error)
179 && FloatCore::is_infinite(self.best_error)
180 && FloatCore::is_sign_positive(self.error)
181 == FloatCore::is_sign_positive(self.best_error))
182 {
183 std::mem::swap(&mut self.prev_best_error, &mut self.best_error);
184 self.best_error = self.error;
185 self.last_best_iter = self.iter;
186
187 specific.last_was_best();
188 }
189 }
190 Some(UpdateData::Complete) => {
192 return self
193 .set_specific(specific)
194 .terminate_due_to(Cause::Converged);
195 }
196 _ => (),
197 };
198
199 self = self.set_specific(specific);
200
201 if self.error < self.absolute_tolerance {
202 return self.terminate_due_to(Cause::Converged);
203 }
204 if self.current_iteration() > self.max_iter {
205 return self.terminate_due_to(Cause::ExceededMaxIterations);
206 }
207
208 self
209 }
210
211 pub(crate) fn get_param(&self) -> Option<&S::Param> {
213 self.specific
214 .as_ref()
215 .and_then(|specific| specific.get_param())
216 }
217
218 pub(crate) fn measure(&self) -> S::Float {
220 self.error
221 }
222
223 pub(crate) fn best_measure(&self) -> S::Float {
225 self.best_error
226 }
227
228 pub fn take_specific(&mut self) -> S {
230 self.specific.take().unwrap()
231 }
232
233 #[must_use]
234 pub fn relative_tolerance(mut self, relative_tolerance: S::Float) -> Self {
236 self.relative_tolerance = relative_tolerance;
237 self
238 }
239
240 #[must_use]
241 pub fn absolute_tolerance(mut self, absolute_tolerance: S::Float) -> Self {
243 self.absolute_tolerance = absolute_tolerance;
244 self
245 }
246
247 #[must_use]
248 pub fn max_iters(mut self, max_iter: usize) -> Self {
250 self.max_iter = max_iter;
251 self
252 }
253
254 #[must_use]
255 pub fn set_specific(mut self, specific: S) -> Self {
257 self.specific = Some(specific);
258 self
259 }
260}