twine_core/
simulation.rs

1use std::{iter::FusedIterator, time::Duration};
2
3use crate::model::Model;
4
5/// Trait for defining a transient simulation in Twine.
6///
7/// A `Simulation` advances a [`Model`] forward in time by computing its next
8/// input with [`advance_time`] and then calling the model to produce a new
9/// [`State`] representing the system at the corresponding future moment.
10///
11/// # Stepping Methods
12///
13/// After implementing [`advance_time`], the following methods are available for
14/// advancing the simulation:
15///
16/// - [`Simulation::step`]: Takes a single step from an initial input.
17/// - [`Simulation::step_from_state`]: Takes a single step from a known state.
18/// - [`Simulation::step_many`]: Takes multiple steps and collects all resulting states.
19/// - [`Simulation::into_step_iter`]: Consumes the simulation and returns an iterator over its steps.
20pub trait Simulation<M: Model>: Sized {
21    /// The error type returned if a simulation step fails.
22    ///
23    /// This type must implement [`From<M::Error>`] so errors produced by the model
24    /// (via [`Model::call`]) can be automatically converted using the `?` operator.
25    /// This requirement allows simulations to propagate model errors cleanly
26    /// when calling the model during a step or within [`advance_time`].
27    ///
28    /// Implementations may:
29    /// - Reuse the model's error type directly (`type StepError = M::Error`).
30    /// - Wrap it in a custom enum with additional error variants.
31    /// - Use boxed dynamic errors for maximum flexibility.
32    type StepError: std::error::Error + Send + Sync + 'static + From<M::Error>;
33
34    /// Provides a reference to the model being simulated.
35    fn model(&self) -> &M;
36
37    /// Computes the next input for the model, advancing the simulation in time.
38    ///
39    /// Given the current [`State`] and a proposed time step `dt`, this method
40    /// generates the next [`Model::Input`] to drive the simulation forward.
41    ///
42    /// This method is the primary customization point for incrementing time,
43    /// integrating state variables, enforcing constraints, applying control
44    /// logic, or incorporating external events.
45    /// It takes `&mut self` to support stateful integration algorithms such as
46    /// adaptive time stepping, multistep methods, or PID controllers that need
47    /// to record history.
48    ///
49    /// Implementations may interpret or adapt the proposed time step `dt` as
50    /// needed (e.g., for adaptive time stepping), and are free to update any
51    /// fields of the input required to continue the simulation.
52    ///
53    /// # Parameters
54    ///
55    /// - `state`: The current simulation state.
56    /// - `dt`: The proposed time step.
57    ///
58    /// # Returns
59    ///
60    /// The next input, computed from the current [`State`] and proposed `dt`.
61    ///
62    /// # Errors
63    ///
64    /// Returns a [`StepError`] if computing the next input fails.
65    fn advance_time(&mut self, state: &State<M>, dt: Duration)
66    -> Result<M::Input, Self::StepError>;
67
68    /// Advances the simulation by one step, starting from an initial input.
69    ///
70    /// This method first calls the model with the given input to compute
71    /// the initial output, forming a complete [`State`].
72    /// It then delegates to [`step_from_state`] to compute the next state.
73    /// As a result, the model is called twice: once to initialize the
74    /// state, and once after advancing.
75    ///
76    /// # Parameters
77    ///
78    /// - `input`: The model input at the start of the step.
79    /// - `dt`: The proposed time step.
80    ///
81    /// # Errors
82    ///
83    /// Returns a [`StepError`] if computing the next input or calling the model fails.
84    fn step(&mut self, input: M::Input, dt: Duration) -> Result<State<M>, Self::StepError> {
85        let output = self.model().call(&input)?;
86        let state = State::new(input, output);
87
88        self.step_from_state(&state, dt)
89    }
90
91    /// Advances the simulation by one step from a known [`State`].
92    ///
93    /// This method computes the next input using [`advance_time`],
94    /// then calls the model to produce the resulting [`State`].
95    ///
96    /// # Parameters
97    ///
98    /// - `state`: The current simulation state.
99    /// - `dt`: The proposed time step.
100    ///
101    /// # Errors
102    ///
103    /// Returns a [`StepError`] if computing the next input or calling the model fails.
104    fn step_from_state(
105        &mut self,
106        state: &State<M>,
107        dt: Duration,
108    ) -> Result<State<M>, Self::StepError> {
109        let input = self.advance_time(state, dt)?;
110        let output = self.model().call(&input)?;
111
112        Ok(State::new(input, output))
113    }
114
115    /// Runs the simulation for a fixed number of steps and collects the results.
116    ///
117    /// Starting from the given input, this method advances the simulation by
118    /// `steps` iterations using the proposed time step `dt`.
119    ///
120    /// # Parameters
121    ///
122    /// - `initial_input`: The model input at the start of the simulation.
123    /// - `steps`: The number of steps to run.
124    /// - `dt`: The proposed time step for each iteration.
125    ///
126    /// # Returns
127    ///
128    /// A `Vec` of length `steps + 1` containing each [`State`] computed during
129    /// the run, including the initial one.
130    ///
131    /// # Errors
132    ///
133    /// Returns a [`StepError`] if any step fails.
134    /// No further steps are taken after an error.
135    fn step_many(
136        &mut self,
137        initial_input: M::Input,
138        steps: usize,
139        dt: Duration,
140    ) -> Result<Vec<State<M>>, Self::StepError> {
141        let mut results = Vec::with_capacity(steps + 1);
142
143        let output = self.model().call(&initial_input)?;
144        results.push(State::new(initial_input, output));
145
146        for _ in 0..steps {
147            let last = results.last().expect("results not empty");
148            let next = self.step_from_state(last, dt)?;
149            results.push(next);
150        }
151
152        Ok(results)
153    }
154
155    /// Consumes the simulation and creates an iterator that advances it repeatedly.
156    ///
157    /// The iterator calls the simulation's stepping logic with a constant `dt`,
158    /// yielding each resulting [`State`] in sequence.
159    /// If a step fails, the error is returned and iteration stops.
160    ///
161    /// This method supports lazy or streaming evaluation and integrates cleanly
162    /// with iterator adapters such as `.take(n)`, `.map(...)`, or `.find(...)`.
163    /// It is memory-efficient and performs no intermediate allocations.
164    ///
165    /// # Parameters
166    ///
167    /// - `initial_input`: The model input at the start of the simulation.
168    /// - `dt`: The proposed time step for each iteration.
169    ///
170    /// # Returns
171    ///
172    /// An iterator over `Result<State<M>, StepError>` steps.
173    fn into_step_iter(
174        self,
175        initial_input: M::Input,
176        dt: Duration,
177    ) -> impl Iterator<Item = Result<State<M>, Self::StepError>>
178    where
179        M::Input: Clone,
180        M::Output: Clone,
181    {
182        StepIter {
183            dt,
184            known: Some(Known::Input(initial_input)),
185            sim: self,
186        }
187    }
188}
189
190/// Represents a snapshot of the simulation at a specific point in time.
191///
192/// A [`State`] pairs:
193/// - `input`: The independent variables, typically user-controlled or time-evolving.
194/// - `output`: The dependent variables, computed by the model.
195///
196/// Together, these describe the full state of the system at a given instant.
197#[derive(Debug, Default, PartialEq, PartialOrd)]
198pub struct State<M: Model> {
199    pub input: M::Input,
200    pub output: M::Output,
201}
202
203impl<M: Model> State<M> {
204    /// Creates a [`State`] from the provided input and output.
205    pub fn new(input: M::Input, output: M::Output) -> Self {
206        Self { input, output }
207    }
208}
209
210impl<M: Model> Clone for State<M>
211where
212    M::Input: Clone,
213    M::Output: Clone,
214{
215    fn clone(&self) -> Self {
216        Self {
217            input: self.input.clone(),
218            output: self.output.clone(),
219        }
220    }
221}
222
223impl<M: Model> Copy for State<M>
224where
225    M::Input: Copy,
226    M::Output: Copy,
227{
228}
229
230/// An iterator that repeatedly steps the simulation using a proposed time step.
231///
232/// Starting from an initial input, this iterator repeatedly steps the
233/// simulation using `dt`, yielding each resulting [`State`] as a `Result`.
234///
235/// If any step fails, the error is yielded and iteration stops.
236struct StepIter<M: Model, S: Simulation<M>> {
237    dt: Duration,
238    known: Option<Known<M>>,
239    sim: S,
240}
241
242/// Internal state held by the [`StepIter`] iterator.
243enum Known<M: Model> {
244    /// The simulation has only been initialized with an input.
245    Input(M::Input),
246    /// The full simulation state is available.
247    State(State<M>),
248}
249
250impl<M, S> Iterator for StepIter<M, S>
251where
252    M: Model,
253    S: Simulation<M>,
254    M::Input: Clone,
255    M::Output: Clone,
256{
257    type Item = Result<State<M>, S::StepError>;
258
259    /// Advances the simulation by one step.
260    ///
261    /// - If starting from an input, calls the model to produce the first state.
262    /// - If continuing from a full state, steps the simulation forward.
263    /// - On success, yields a new [`State`].
264    /// - On error, yields a [`StepError`] and ends the iteration.
265    fn next(&mut self) -> Option<Self::Item> {
266        let known = self.known.take()?;
267
268        match known {
269            // A full state exists - step forward from it.
270            Known::State(state) => match self.sim.step_from_state(&state, self.dt) {
271                Ok(state) => {
272                    self.known = Some(Known::State(State::new(
273                        state.input.clone(),
274                        state.output.clone(),
275                    )));
276                    Some(Ok(state))
277                }
278                Err(error) => {
279                    self.known = None;
280                    Some(Err(error))
281                }
282            },
283
284            // Only the input is known - call the model and yield the first state.
285            Known::Input(input) => match self.sim.model().call(&input) {
286                Ok(output) => {
287                    self.known = Some(Known::State(State::new(input.clone(), output.clone())));
288                    let state = State::new(input, output);
289                    Some(Ok(state))
290                }
291                Err(error) => {
292                    self.known = None;
293                    Some(Err(error.into()))
294                }
295            },
296        }
297    }
298}
299
300/// Marks that iteration always ends after the first `None`.
301impl<M, S> FusedIterator for StepIter<M, S>
302where
303    M: Model,
304    S: Simulation<M>,
305    M::Input: Clone,
306    M::Output: Clone,
307{
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    use std::convert::Infallible;

    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use thiserror::Error;

    /// A simple spring-damper model used for simulation tests.
    #[derive(Debug)]
    struct Spring {
        spring_constant: f64,
        damping_coef: f64,
    }

    #[derive(Debug, Clone, Default, PartialEq)]
    struct Input {
        time_in_minutes: f64,
        position: f64,
        velocity: f64,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct Output {
        acceleration: f64,
    }

    impl Model for Spring {
        type Input = Input;
        type Output = Output;
        type Error = Infallible;

        fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
            let Input {
                position, velocity, ..
            } = input;

            let acceleration = -self.spring_constant * position - self.damping_coef * velocity;

            Ok(Output { acceleration })
        }
    }

    #[derive(Debug)]
    struct SpringSimulation {
        model: Spring,
    }

    impl Simulation<Spring> for SpringSimulation {
        type StepError = Infallible;

        fn model(&self) -> &Spring {
            &self.model
        }

        /// Explicit Euler update: position uses the old velocity, velocity
        /// uses the acceleration from the current output.
        fn advance_time(
            &mut self,
            state: &State<Spring>,
            dt: Duration,
        ) -> Result<Input, Self::StepError> {
            let seconds = dt.as_secs_f64();
            let time_in_minutes = state.input.time_in_minutes + seconds / 60.0;

            let position = state.input.position + state.input.velocity * seconds;
            let velocity = state.input.velocity + state.output.acceleration * seconds;

            Ok(Input {
                time_in_minutes,
                position,
                velocity,
            })
        }
    }

    #[test]
    fn zero_force_spring_has_constant_velocity() {
        let mut sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.0,
                damping_coef: 0.0,
            },
        };

        let initial = Input {
            time_in_minutes: 0.0,
            position: 10.0,
            velocity: 2.0,
        };

        let steps = 3;
        let dt = Duration::from_secs(30);

        let states = sim.step_many(initial, steps, dt).unwrap();

        let input_states: Vec<_> = states.iter().map(|s| s.input.clone()).collect();

        assert_eq!(
            input_states,
            vec![
                Input {
                    time_in_minutes: 0.0,
                    position: 10.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 0.5,
                    position: 70.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.0,
                    position: 130.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.5,
                    position: 190.0,
                    velocity: 2.0
                },
            ]
        );

        assert!(
            states.iter().all(|s| s.output.acceleration == 0.0),
            "All accelerations should be zero"
        );
    }

    #[test]
    fn damped_spring_sim_converges_to_zero() {
        let sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.5,
                damping_coef: 5.0,
            },
        };

        let initial = Input {
            position: 10.0,
            ..Input::default()
        };

        let dt = Duration::from_millis(100);

        let tolerance = 1e-4;
        let max_steps = 5000;

        let is_at_rest = |state: &State<Spring>| {
            state.input.position.abs() < tolerance
                && state.input.velocity.abs() < tolerance
                && state.output.acceleration.abs() < tolerance
        };

        // Use the step iterator to find the first state close enough to zero.
        let final_state = sim
            .into_step_iter(initial, dt)
            .take(max_steps)
            .find_map(|res| match res {
                Ok(state) if is_at_rest(&state) => Some(state),
                Ok(_) => None,
                Err(error) => panic!("Simulation error: {error:?}"),
            })
            // `Option::expect` does not interpolate `{max_steps}`, so build the
            // message with `panic!` to report the actual step limit.
            .unwrap_or_else(|| {
                panic!("Simulation did not reach a resting state within {max_steps} steps")
            });

        let State {
            input: final_input,
            output: final_output,
        } = final_state;

        assert_abs_diff_eq!(final_input.position, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_input.velocity, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_output.acceleration, 0.0, epsilon = tolerance);

        assert_relative_eq!(final_input.time_in_minutes, 1.875, epsilon = tolerance);
    }

    /// A model that fails if the input exceeds a specified maximum.
    #[derive(Debug)]
    struct CheckInput {
        max_value: usize,
    }

    impl Model for CheckInput {
        type Input = usize;
        type Output = ();
        type Error = CheckInputError;

        fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
            if *input <= self.max_value {
                Ok(())
            } else {
                Err(CheckInputError(*input, self.max_value))
            }
        }
    }

    #[derive(Debug, Error)]
    #[error("{0} is bigger than max value of {1}")]
    struct CheckInputError(usize, usize);

    /// A test simulation using [`CheckInput`].
    ///
    /// Each step increments the input by 1.
    /// Yields an error when the input exceeds the maximum threshold `N`.
    #[derive(Debug)]
    struct CheckInputSim<const N: usize>;

    impl<const N: usize> Simulation<CheckInput> for CheckInputSim<N> {
        type StepError = CheckInputError;

        fn model(&self) -> &CheckInput {
            &CheckInput { max_value: N }
        }

        fn advance_time(
            &mut self,
            state: &State<CheckInput>,
            _dt: Duration,
        ) -> Result<usize, Self::StepError> {
            Ok(state.input + 1)
        }
    }

    #[test]
    fn step_iter_yields_error_correctly() {
        let mut iter = CheckInputSim::<3>.into_step_iter(0, Duration::from_secs(1));

        let state = iter
            .next()
            .expect("Initial call yields a result")
            .expect("Initial call is a success");
        assert_eq!(state.input, 0);

        let state = iter
            .next()
            .expect("First step yields a result")
            .expect("First step is a success");
        assert_eq!(state.input, 1);

        let state = iter
            .next()
            .expect("Second step yields a result")
            .expect("Second step is a success");
        assert_eq!(state.input, 2);

        let state = iter
            .next()
            .expect("Third step yields a result")
            .expect("Third step is a success");
        assert_eq!(state.input, 3);

        let error = iter
            .next()
            .expect("Fourth step yields a result")
            .expect_err("Fourth step is an error");
        assert_eq!(format!("{error}"), "4 is bigger than max value of 3");

        assert!(iter.next().is_none());
    }
}
570}