1use num_traits::float::FloatCore;
87use std::sync::{Arc, Mutex};
88use tokio_util::sync::CancellationToken;
89
90use crate::engine::policy::{CancellationPolicy, CompletionPolicy, EnginePolicy, PolicyStack};
91use crate::{
92 engine::{
93 checkpoint::{CheckpointBackend, CheckpointExtension},
94 extensions::Extensions,
95 Engine,
96 },
97 state::{Snapshotable, State, StateRestorer},
98 watchers::{Frequency, Observe, Observers},
99 FallibleProcedure, Infallible, Procedure, UserState,
100};
101
102pub trait GenerateBuilderFallible: Sized {
103 fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
104 where
105 Self: FallibleProcedure<P>,
106 Self::State: UserState;
107}
108
109impl<Proc> GenerateBuilderFallible for Proc {
110 fn build_for<P>(self, problem: P) -> Builder<Self, P, Uninitialised>
111 where
112 Proc: FallibleProcedure<P>,
113 Proc::State: UserState,
114 {
115 Builder {
116 procedure: self,
117 problem,
118 state: None,
119 time: true,
120 cancellation_token: None,
121
122 observers: Observers::new(),
123
124 policies: PolicyStack::new()
125 .add(CancellationPolicy)
126 .add(CompletionPolicy),
127
128 extensions: Extensions::new(),
129
130 _initialised: std::marker::PhantomData,
131 }
132 }
133}
134
135pub trait GenerateBuilder: Sized {
136 fn build_for<P>(self, problem: P) -> Builder<Infallible<Self>, P, Uninitialised>
137 where
138 Self: Procedure<P>,
139 Self::State: UserState;
140}
141
142impl<Proc> GenerateBuilder for Proc {
143 fn build_for<P>(self, problem: P) -> Builder<Infallible<Self>, P, Uninitialised>
144 where
145 Proc: Procedure<P>,
146 Proc::State: UserState,
147 {
148 Builder {
149 procedure: Infallible(self),
150 problem,
151 state: None,
152 time: true,
153 cancellation_token: None,
154
155 observers: Observers::new(),
156
157 policies: PolicyStack::new()
158 .add(CancellationPolicy)
159 .add(CompletionPolicy),
160
161 extensions: Extensions::new(),
162
163 _initialised: std::marker::PhantomData,
164 }
165 }
166}
167
/// Typestate marker: the [`Builder`] has not been given an initial user state yet.
pub struct Uninitialised;
/// Typestate marker: the [`Builder`] holds an initial user state and can be finalised.
pub struct Initialised;
170
/// Typestate builder that assembles an [`Engine`] for the procedure `Proc` and the
/// problem `P`.
///
/// `I` is [`Uninitialised`] or [`Initialised`]; `finalise`/`finalise_with` are only
/// available once an initial state has been supplied (directly or from a checkpoint).
pub struct Builder<Proc, P, I>
where
    Proc: FallibleProcedure<P>,
    Proc::State: UserState,
    <Proc::State as UserState>::Float: FloatCore,
{
    // Procedure moved into the engine on finalisation.
    procedure: Proc,
    // Problem value moved into the engine alongside the procedure.
    problem: P,
    // Initial user state: `None` while `Uninitialised`, `Some` once supplied or
    // restored from a checkpoint.
    state: Option<Proc::State>,
    // Copied into the engine's `time` flag; the `build_for` entry points default it to `true`.
    time: bool,
    // Caller-supplied token; finalisation falls back to a default token when absent.
    cancellation_token: Option<CancellationToken>,

    // Observers registered via `attach_observer`, each stored as `Arc<Mutex<_>>`.
    observers: Observers<Proc::State>,

    // Policy stack handed to the engine; seeded with `CancellationPolicy` and
    // `CompletionPolicy` by the `build_for` entry points.
    policies: PolicyStack<<Proc::State as UserState>::Float>,
    // Engine extensions, e.g. the checkpoint extension added by `with_checkpoint_backend`.
    extensions: Extensions<Proc::State>,

    // Zero-sized typestate marker.
    _initialised: std::marker::PhantomData<I>,
}
190
191impl<Proc, P, I> Builder<Proc, P, I>
192where
193 Proc: FallibleProcedure<P>,
194 Proc::State: UserState,
195 <Proc::State as UserState>::Float: FloatCore + 'static,
196{
197 #[must_use]
198 pub fn time(mut self, time: bool) -> Self {
199 self.time = time;
200 self
201 }
202
203 #[must_use]
205 pub fn attach_observer<OBS>(mut self, observer: OBS, frequency: Frequency) -> Self
206 where
207 OBS: Observe<Proc::State> + 'static,
208 {
209 self.observers
210 .attach(Arc::new(Mutex::new(observer)), frequency);
211 self
212 }
213
214 #[must_use]
215 pub fn and_policy<Q>(mut self, policy: Q) -> Self
216 where
217 Q: EnginePolicy<<Proc::State as UserState>::Float> + 'static,
218 {
219 self.policies = self.policies.add(policy);
220 self
221 }
222
223 #[must_use]
224 pub fn cancellation_token(mut self, token: CancellationToken) -> Self {
225 self.cancellation_token = Some(token);
226 self
227 }
228
229 #[must_use]
230 pub fn with_default_policies(
234 mut self,
235 max_iter: usize,
236 absolute_tolerance: <Proc::State as UserState>::Float,
237 window_size: usize,
238 ) -> Self {
239 self.policies = self.policies.merge(PolicyStack::standard(
240 max_iter,
241 absolute_tolerance,
242 window_size,
243 ));
244 self
245 }
246
247 #[must_use]
248 pub fn with_checkpoint_backend<C>(mut self, store: C) -> Self
254 where
255 C: CheckpointBackend<
256 <Proc::State as Snapshotable>::Snapshot,
257 <Proc::State as UserState>::Float,
258 > + 'static,
259 Proc::State: Snapshotable,
260 {
261 self.extensions = self.extensions.add(CheckpointExtension::new(store));
262 self
263 }
264}
265
266impl<Proc, P> Builder<Proc, P, Uninitialised>
267where
268 Proc: FallibleProcedure<P>,
269 Proc::State: UserState,
270 <Proc::State as UserState>::Float: FloatCore + 'static,
271{
272 #[must_use]
274 pub fn with_initial_state(self, user: Proc::State) -> Builder<Proc, P, Initialised> {
275 Builder {
276 procedure: self.procedure,
277 problem: self.problem,
278 state: Some(user),
279 time: self.time,
280 cancellation_token: self.cancellation_token,
281
282 observers: self.observers,
283
284 policies: self.policies,
285
286 extensions: self.extensions,
287
288 _initialised: std::marker::PhantomData,
289 }
290 }
291
292 #[must_use]
293 pub fn resume_from_checkpoint(
294 self,
295 snapshot: <Proc::State as Snapshotable>::Snapshot,
296 ) -> Builder<Proc, P, Initialised>
297 where
298 Proc: FallibleProcedure<P>,
299 Proc::State: Snapshotable + StateRestorer<Proc::State>,
300 {
301 let user = Proc::State::restore(snapshot);
302
303 Builder {
304 procedure: self.procedure,
305 problem: self.problem,
306 state: Some(user),
307 time: self.time,
308 cancellation_token: self.cancellation_token,
309
310 observers: self.observers,
311
312 policies: self.policies,
313
314 extensions: self.extensions,
315
316 _initialised: std::marker::PhantomData,
317 }
318 }
319}
320
321impl<Proc, P> Builder<Proc, P, Initialised>
322where
323 Proc: FallibleProcedure<P>,
324 Proc::State: UserState,
325 <Proc::State as UserState>::Float: FloatCore + 'static,
326{
327 pub fn finalise(mut self) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>>
332 where
333 <Proc::State as UserState>::Float: num_traits::FromPrimitive,
334 {
335 let user = self.state.take().expect("builder invariant: user is set");
336
337 let cancellation = self.cancellation_token.unwrap_or_default();
338
339 #[cfg(feature = "ctrlc")]
340 {
341 let token = cancellation.clone();
342 ctrlc::set_handler(move || {
343 token.cancel();
344 })
345 .unwrap();
346 }
347
348 Engine {
349 procedure: self.procedure,
350 problem: self.problem,
351 state: State::new(user),
352
353 time: self.time,
354 start_time: None,
355
356 cancellation,
357
358 policy: self.policies,
359
360 observers: self.observers,
361 extensions: self.extensions,
362 }
363 }
364
365 pub fn finalise_with(
373 mut self,
374 policy: PolicyStack<<Proc::State as UserState>::Float>,
375 ) -> Engine<Proc, P, PolicyStack<<Proc::State as UserState>::Float>> {
376 let user = self.state.take().expect("builder invariant: user is set");
377 let cancellation = self.cancellation_token.unwrap_or_default();
378
379 #[cfg(feature = "ctrlc")]
380 {
381 let token = cancellation.clone();
382 ctrlc::set_handler(move || {
383 token.cancel();
384 })
385 .unwrap();
386 }
387
388 Engine {
389 procedure: self.procedure,
390 problem: self.problem,
391 state: State::new(user),
392
393 time: self.time,
394 start_time: None,
395
396 cancellation,
397
398 policy,
399
400 observers: self.observers,
401 extensions: self.extensions,
402 }
403 }
404}
405