twine_core/simulation.rs
1use std::{iter::FusedIterator, time::Duration};
2
3use crate::model::Model;
4
5/// Trait for defining a transient simulation in Twine.
6///
7/// A `Simulation` advances a [`Model`] forward in time by computing its next
8/// input with [`advance_time`] and then calling the model to produce a new
9/// [`State`] representing the system at the corresponding future moment.
10///
11/// # Stepping Methods
12///
13/// After implementing [`advance_time`], the following methods are available for
14/// advancing the simulation:
15///
16/// - [`Simulation::step`]: Takes a single step from an initial input.
17/// - [`Simulation::step_from_state`]: Takes a single step from a known state.
18/// - [`Simulation::step_many`]: Takes multiple steps and collects all resulting states.
19/// - [`Simulation::into_step_iter`]: Consumes the simulation and returns an iterator over its steps.
20pub trait Simulation<M: Model>: Sized {
21 /// The error type returned if a simulation step fails.
22 ///
23 /// This type must implement [`From<M::Error>`] so errors produced by the model
24 /// (via [`Model::call`]) can be automatically converted using the `?` operator.
25 /// This requirement allows simulations to propagate model errors cleanly
26 /// when calling the model during a step or within [`advance_time`].
27 ///
28 /// Implementations may:
29 /// - Reuse the model's error type directly (`type StepError = M::Error`).
30 /// - Wrap it in a custom enum with additional error variants.
31 /// - Use boxed dynamic errors for maximum flexibility.
32 type StepError: std::error::Error + Send + Sync + 'static + From<M::Error>;
33
34 /// Provides a reference to the model being simulated.
35 fn model(&self) -> &M;
36
37 /// Computes the next input for the model, advancing the simulation in time.
38 ///
39 /// Given the current [`State`] and a proposed time step `dt`, this method
40 /// generates the next [`Model::Input`] to drive the simulation forward.
41 ///
42 /// This method is the primary customization point for incrementing time,
43 /// integrating state variables, enforcing constraints, applying control
44 /// logic, or incorporating external events.
45 /// It takes `&mut self` to support stateful integration algorithms such as
46 /// adaptive time stepping, multistep methods, or PID controllers that need
47 /// to record history.
48 ///
49 /// Implementations may interpret or adapt the proposed time step `dt` as
50 /// needed (e.g., for adaptive time stepping), and are free to update any
51 /// fields of the input required to continue the simulation.
52 ///
53 /// # Parameters
54 ///
55 /// - `state`: The current simulation state.
56 /// - `dt`: The proposed time step.
57 ///
58 /// # Returns
59 ///
60 /// The next input, computed from the current [`State`] and proposed `dt`.
61 ///
62 /// # Errors
63 ///
64 /// Returns a [`StepError`] if computing the next input fails.
65 fn advance_time(&mut self, state: &State<M>, dt: Duration)
66 -> Result<M::Input, Self::StepError>;
67
68 /// Advances the simulation by one step, starting from an initial input.
69 ///
70 /// This method first calls the model with the given input to compute
71 /// the initial output, forming a complete [`State`].
72 /// It then delegates to [`step_from_state`] to compute the next state.
73 /// As a result, the model is called twice: once to initialize the
74 /// state, and once after advancing.
75 ///
76 /// # Parameters
77 ///
78 /// - `input`: The model input at the start of the step.
79 /// - `dt`: The proposed time step.
80 ///
81 /// # Errors
82 ///
83 /// Returns a [`StepError`] if computing the next input or calling the model fails.
84 fn step(&mut self, input: M::Input, dt: Duration) -> Result<State<M>, Self::StepError> {
85 let output = self.model().call(&input)?;
86 let state = State::new(input, output);
87
88 self.step_from_state(&state, dt)
89 }
90
91 /// Advances the simulation by one step from a known [`State`].
92 ///
93 /// This method computes the next input using [`advance_time`],
94 /// then calls the model to produce the resulting [`State`].
95 ///
96 /// # Parameters
97 ///
98 /// - `state`: The current simulation state.
99 /// - `dt`: The proposed time step.
100 ///
101 /// # Errors
102 ///
103 /// Returns a [`StepError`] if computing the next input or calling the model fails.
104 fn step_from_state(
105 &mut self,
106 state: &State<M>,
107 dt: Duration,
108 ) -> Result<State<M>, Self::StepError> {
109 let input = self.advance_time(state, dt)?;
110 let output = self.model().call(&input)?;
111
112 Ok(State::new(input, output))
113 }
114
115 /// Runs the simulation for a fixed number of steps and collects the results.
116 ///
117 /// Starting from the given input, this method advances the simulation by
118 /// `steps` iterations using the proposed time step `dt`.
119 ///
120 /// # Parameters
121 ///
122 /// - `initial_input`: The model input at the start of the simulation.
123 /// - `steps`: The number of steps to run.
124 /// - `dt`: The proposed time step for each iteration.
125 ///
126 /// # Returns
127 ///
128 /// A `Vec` of length `steps + 1` containing each [`State`] computed during
129 /// the run, including the initial one.
130 ///
131 /// # Errors
132 ///
133 /// Returns a [`StepError`] if any step fails.
134 /// No further steps are taken after an error.
135 fn step_many(
136 &mut self,
137 initial_input: M::Input,
138 steps: usize,
139 dt: Duration,
140 ) -> Result<Vec<State<M>>, Self::StepError> {
141 let mut results = Vec::with_capacity(steps + 1);
142
143 let output = self.model().call(&initial_input)?;
144 results.push(State::new(initial_input, output));
145
146 for _ in 0..steps {
147 let last = results.last().expect("results not empty");
148 let next = self.step_from_state(last, dt)?;
149 results.push(next);
150 }
151
152 Ok(results)
153 }
154
155 /// Consumes the simulation and creates an iterator that advances it repeatedly.
156 ///
157 /// The iterator calls the simulation's stepping logic with a constant `dt`,
158 /// yielding each resulting [`State`] in sequence.
159 /// If a step fails, the error is returned and iteration stops.
160 ///
161 /// This method supports lazy or streaming evaluation and integrates cleanly
162 /// with iterator adapters such as `.take(n)`, `.map(...)`, or `.find(...)`.
163 /// It is memory-efficient and performs no intermediate allocations.
164 ///
165 /// # Parameters
166 ///
167 /// - `initial_input`: The model input at the start of the simulation.
168 /// - `dt`: The proposed time step for each iteration.
169 ///
170 /// # Returns
171 ///
172 /// An iterator over `Result<State<M>, StepError>` steps.
173 fn into_step_iter(
174 self,
175 initial_input: M::Input,
176 dt: Duration,
177 ) -> impl Iterator<Item = Result<State<M>, Self::StepError>>
178 where
179 M::Input: Clone,
180 M::Output: Clone,
181 {
182 StepIter {
183 dt,
184 known: Some(Known::Input(initial_input)),
185 sim: self,
186 }
187 }
188}
189
190/// Represents a snapshot of the simulation at a specific point in time.
191///
192/// A [`State`] pairs:
193/// - `input`: The independent variables, typically user-controlled or time-evolving.
194/// - `output`: The dependent variables, computed by the model.
195///
196/// Together, these describe the full state of the system at a given instant.
197#[derive(Debug, Default, PartialEq, PartialOrd)]
198pub struct State<M: Model> {
199 pub input: M::Input,
200 pub output: M::Output,
201}
202
203impl<M: Model> State<M> {
204 /// Creates a [`State`] from the provided input and output.
205 pub fn new(input: M::Input, output: M::Output) -> Self {
206 Self { input, output }
207 }
208}
209
210impl<M: Model> Clone for State<M>
211where
212 M::Input: Clone,
213 M::Output: Clone,
214{
215 fn clone(&self) -> Self {
216 Self {
217 input: self.input.clone(),
218 output: self.output.clone(),
219 }
220 }
221}
222
223impl<M: Model> Copy for State<M>
224where
225 M::Input: Copy,
226 M::Output: Copy,
227{
228}
229
230/// An iterator that repeatedly steps the simulation using a proposed time step.
231///
232/// Starting from an initial input, this iterator repeatedly steps the
233/// simulation using `dt`, yielding each resulting [`State`] as a `Result`.
234///
235/// If any step fails, the error is yielded and iteration stops.
236struct StepIter<M: Model, S: Simulation<M>> {
237 dt: Duration,
238 known: Option<Known<M>>,
239 sim: S,
240}
241
242/// Internal state held by the [`StepIter`] iterator.
243enum Known<M: Model> {
244 /// The simulation has only been initialized with an input.
245 Input(M::Input),
246 /// The full simulation state is available.
247 State(State<M>),
248}
249
250impl<M, S> Iterator for StepIter<M, S>
251where
252 M: Model,
253 S: Simulation<M>,
254 M::Input: Clone,
255 M::Output: Clone,
256{
257 type Item = Result<State<M>, S::StepError>;
258
259 /// Advances the simulation by one step.
260 ///
261 /// - If starting from an input, calls the model to produce the first state.
262 /// - If continuing from a full state, steps the simulation forward.
263 /// - On success, yields a new [`State`].
264 /// - On error, yields a [`StepError`] and ends the iteration.
265 fn next(&mut self) -> Option<Self::Item> {
266 let known = self.known.take()?;
267
268 match known {
269 // A full state exists - step forward from it.
270 Known::State(state) => match self.sim.step_from_state(&state, self.dt) {
271 Ok(state) => {
272 self.known = Some(Known::State(State::new(
273 state.input.clone(),
274 state.output.clone(),
275 )));
276 Some(Ok(state))
277 }
278 Err(error) => {
279 self.known = None;
280 Some(Err(error))
281 }
282 },
283
284 // Only the input is known - call the model and yield the first state.
285 Known::Input(input) => match self.sim.model().call(&input) {
286 Ok(output) => {
287 self.known = Some(Known::State(State::new(input.clone(), output.clone())));
288 let state = State::new(input, output);
289 Some(Ok(state))
290 }
291 Err(error) => {
292 self.known = None;
293 Some(Err(error.into()))
294 }
295 },
296 }
297 }
298}
299
300/// Marks that iteration always ends after the first `None`.
301impl<M, S> FusedIterator for StepIter<M, S>
302where
303 M: Model,
304 S: Simulation<M>,
305 M::Input: Clone,
306 M::Output: Clone,
307{
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    use std::convert::Infallible;

    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use thiserror::Error;

    /// A simple spring-damper model used for simulation tests.
    #[derive(Debug)]
    struct Spring {
        spring_constant: f64,
        damping_coef: f64,
    }

    /// Input to [`Spring`]: the current time plus position and velocity.
    #[derive(Debug, Clone, Default, PartialEq)]
    struct Input {
        time_in_minutes: f64,
        position: f64,
        velocity: f64,
    }

    /// Output of [`Spring`]: the acceleration computed from an [`Input`].
    #[derive(Debug, Clone, PartialEq)]
    struct Output {
        acceleration: f64,
    }

    impl Model for Spring {
        type Input = Input;
        type Output = Output;
        type Error = Infallible;

        /// Computes `-spring_constant * position - damping_coef * velocity`.
        ///
        /// The time field of the input is ignored, and the call never fails.
        fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
            let Input {
                position, velocity, ..
            } = input;

            let acceleration = -self.spring_constant * position - self.damping_coef * velocity;

            Ok(Output { acceleration })
        }
    }

    /// A simulation that advances a [`Spring`] with one explicit update per step.
    #[derive(Debug)]
    struct SpringSimulation {
        model: Spring,
    }

    impl Simulation<Spring> for SpringSimulation {
        type StepError = Infallible;

        fn model(&self) -> &Spring {
            &self.model
        }

        /// Advances time by `dt` and updates position and velocity from the
        /// previous state's velocity and acceleration, respectively.
        fn advance_time(
            &mut self,
            state: &State<Spring>,
            dt: Duration,
        ) -> Result<Input, Self::StepError> {
            let seconds = dt.as_secs_f64();
            // The input tracks time in minutes while `dt` is handled in seconds.
            let time_in_minutes = state.input.time_in_minutes + seconds / 60.0;

            let position = state.input.position + state.input.velocity * seconds;
            let velocity = state.input.velocity + state.output.acceleration * seconds;

            Ok(Input {
                time_in_minutes,
                position,
                velocity,
            })
        }
    }

    #[test]
    fn zero_force_spring_has_constant_velocity() {
        let mut sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.0,
                damping_coef: 0.0,
            },
        };

        let initial = Input {
            time_in_minutes: 0.0,
            position: 10.0,
            velocity: 2.0,
        };

        let steps = 3;
        let dt = Duration::from_secs(30);

        let states = sim.step_many(initial, steps, dt).unwrap();

        let input_states: Vec<_> = states.iter().map(|s| s.input.clone()).collect();

        // `step_many` returns the initial state followed by one state per step.
        assert_eq!(
            input_states,
            vec![
                Input {
                    time_in_minutes: 0.0,
                    position: 10.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 0.5,
                    position: 70.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.0,
                    position: 130.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.5,
                    position: 190.0,
                    velocity: 2.0
                },
            ]
        );

        assert!(
            states.iter().all(|s| s.output.acceleration == 0.0),
            "All accelerations should be zero"
        );
    }

    #[test]
    fn damped_spring_sim_converges_to_zero() {
        let sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.5,
                damping_coef: 5.0,
            },
        };

        let initial = Input {
            position: 10.0,
            ..Input::default()
        };

        let dt = Duration::from_millis(100);

        let tolerance = 1e-4;
        let max_steps = 5000;

        let is_at_rest = |state: &State<Spring>| {
            state.input.position.abs() < tolerance
                && state.input.velocity.abs() < tolerance
                && state.output.acceleration.abs() < tolerance
        };

        // Use the step iterator to find the first state close enough to zero.
        let final_state = sim
            .into_step_iter(initial, dt)
            .take(max_steps)
            .find_map(|res| match res {
                Ok(state) if is_at_rest(&state) => Some(state),
                Ok(_) => None,
                Err(error) => panic!("Simulation error: {error:?}"),
            })
            // `expect` takes a plain string and would print `{max_steps}`
            // literally, so the message is formatted through `panic!` instead.
            .unwrap_or_else(|| {
                panic!("Simulation did not reach a resting state within {max_steps} steps")
            });

        let State {
            input: final_input,
            output: final_output,
        } = final_state;

        assert_abs_diff_eq!(final_input.position, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_input.velocity, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_output.acceleration, 0.0, epsilon = tolerance);

        assert_relative_eq!(final_input.time_in_minutes, 1.875, epsilon = tolerance);
    }

    /// A model that fails if the input exceeds a specified maximum.
    #[derive(Debug)]
    struct CheckInput {
        max_value: usize,
    }

    impl Model for CheckInput {
        type Input = usize;
        type Output = ();
        type Error = CheckInputError;

        /// Succeeds when `input <= max_value`; otherwise reports both values.
        fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
            if *input <= self.max_value {
                Ok(())
            } else {
                Err(CheckInputError(*input, self.max_value))
            }
        }
    }

    /// Error returned by [`CheckInput`]: the offending input and the allowed maximum.
    #[derive(Debug, Error)]
    #[error("{0} is bigger than max value of {1}")]
    struct CheckInputError(usize, usize);

    /// A test simulation using [`CheckInput`].
    ///
    /// Each step increments the input by 1.
    /// Yields an error when the input exceeds the maximum threshold `N`.
    #[derive(Debug)]
    struct CheckInputSim<const N: usize>;

    impl<const N: usize> Simulation<CheckInput> for CheckInputSim<N> {
        type StepError = CheckInputError;

        // NOTE(review): this returns a reference to a struct literal, which
        // relies on the literal being promoted to a `'static` constant.
        fn model(&self) -> &CheckInput {
            &CheckInput { max_value: N }
        }

        /// Increments the input by one; the proposed time step is ignored.
        fn advance_time(
            &mut self,
            state: &State<CheckInput>,
            _dt: Duration,
        ) -> Result<usize, Self::StepError> {
            Ok(state.input + 1)
        }
    }

    #[test]
    fn step_iter_yields_error_correctly() {
        let mut iter = CheckInputSim::<3>.into_step_iter(0, Duration::from_secs(1));

        let state = iter
            .next()
            .expect("Initial call yields a result")
            .expect("Initial call is a success");
        assert_eq!(state.input, 0);

        let state = iter
            .next()
            .expect("First step yields a result")
            .expect("First step is a success");
        assert_eq!(state.input, 1);

        let state = iter
            .next()
            .expect("Second step yields a result")
            .expect("Second step is a success");
        assert_eq!(state.input, 2);

        let state = iter
            .next()
            .expect("Third step yields a result")
            .expect("Third step is a success");
        assert_eq!(state.input, 3);

        let error = iter
            .next()
            .expect("Fourth step yields a result")
            .expect_err("Fourth step is an error");
        assert_eq!(error.to_string(), "4 is bigger than max value of 3");

        // After yielding the error, the iterator is exhausted.
        assert!(iter.next().is_none());
    }
}