twine_core/
simulation.rs

1use std::{iter::FusedIterator, time::Duration};
2
/// A deterministic system model with typed inputs and outputs.
///
/// A model draws a clear boundary around a system: everything it consumes is
/// carried by [`Model::Input`] and everything it produces is returned as
/// [`Model::Output`]. A model must always return the same result for the same
/// input; that determinism is what allows a [`Simulation`] to evolve it
/// through time.
pub trait Model {
    /// Independent variables handed to the model on each call.
    type Input;
    /// Dependent variables the model computes from its input.
    type Output;
    /// Error reported when the model cannot produce an output.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Evaluates the model for `input`.
    ///
    /// # Errors
    ///
    /// Returns [`Model::Error`] when evaluation fails. Each model decides for
    /// itself what counts as a failure within its own domain.
    fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
}
21
22/// Trait for defining a transient simulation in Twine.
23///
24/// A `Simulation` advances a [`Model`] forward in time by computing its next
25/// input with [`advance_time`] and then calling the model to produce a new
26/// [`State`] representing the system at the corresponding future moment.
27///
28/// # Stepping Methods
29///
30/// After implementing [`advance_time`], the following methods are available for
31/// advancing the simulation:
32///
33/// - [`Simulation::step`]: Takes a single step from an initial input.
34/// - [`Simulation::step_from_state`]: Takes a single step from a known state.
35/// - [`Simulation::step_many`]: Takes multiple steps and collects all resulting states.
36/// - [`Simulation::into_step_iter`]: Consumes the simulation and returns an iterator over its steps.
37pub trait Simulation<M: Model>: Sized {
38    /// The error type returned if a simulation step fails.
39    ///
40    /// This type must implement [`From<M::Error>`] so errors produced by the model
41    /// (via [`Model::call`]) can be automatically converted using the `?` operator.
42    /// This requirement allows simulations to propagate model errors cleanly
43    /// when calling the model during a step or within [`advance_time`].
44    ///
45    /// Implementations may:
46    /// - Reuse the model's error type directly (`type StepError = M::Error`).
47    /// - Wrap it in a custom enum with additional error variants.
48    /// - Use boxed dynamic errors for maximum flexibility.
49    type StepError: std::error::Error + Send + Sync + 'static + From<M::Error>;
50
51    /// Provides a reference to the model being simulated.
52    fn model(&self) -> &M;
53
54    /// Computes the next input for the model, advancing the simulation in time.
55    ///
56    /// Given the current [`State`] and a proposed time step `dt`, this method
57    /// generates the next [`Model::Input`] to drive the simulation forward.
58    ///
59    /// This method is the primary customization point for incrementing time,
60    /// integrating state variables, enforcing constraints, applying control
61    /// logic, or incorporating external events.
62    /// It takes `&mut self` to support stateful integration algorithms such as
63    /// adaptive time stepping, multistep methods, or PID controllers that need
64    /// to record history.
65    ///
66    /// Implementations may interpret or adapt the proposed time step `dt` as
67    /// needed (e.g., for adaptive time stepping), and are free to update any
68    /// fields of the input required to continue the simulation.
69    ///
70    /// # Parameters
71    ///
72    /// - `state`: The current simulation state.
73    /// - `dt`: The proposed time step.
74    ///
75    /// # Returns
76    ///
77    /// The next input, computed from the current [`State`] and proposed `dt`.
78    ///
79    /// # Errors
80    ///
81    /// Returns a [`StepError`] if computing the next input fails.
82    fn advance_time(&mut self, state: &State<M>, dt: Duration)
83    -> Result<M::Input, Self::StepError>;
84
85    /// Advances the simulation by one step, starting from an initial input.
86    ///
87    /// This method first calls the model with the given input to compute
88    /// the initial output, forming a complete [`State`].
89    /// It then delegates to [`step_from_state`] to compute the next state.
90    /// As a result, the model is called twice: once to initialize the
91    /// state, and once after advancing.
92    ///
93    /// # Parameters
94    ///
95    /// - `input`: The model input at the start of the step.
96    /// - `dt`: The proposed time step.
97    ///
98    /// # Errors
99    ///
100    /// Returns a [`StepError`] if computing the next input or calling the model fails.
101    fn step(&mut self, input: M::Input, dt: Duration) -> Result<State<M>, Self::StepError>
102    where
103        M::Input: Clone,
104    {
105        let output = self.model().call(input.clone())?;
106        let state = State::new(input, output);
107
108        self.step_from_state(&state, dt)
109    }
110
111    /// Advances the simulation by one step from a known [`State`].
112    ///
113    /// This method computes the next input using [`advance_time`],
114    /// then calls the model to produce the resulting [`State`].
115    ///
116    /// # Parameters
117    ///
118    /// - `state`: The current simulation state.
119    /// - `dt`: The proposed time step.
120    ///
121    /// # Errors
122    ///
123    /// Returns a [`StepError`] if computing the next input or calling the model fails.
124    fn step_from_state(
125        &mut self,
126        state: &State<M>,
127        dt: Duration,
128    ) -> Result<State<M>, Self::StepError>
129    where
130        M::Input: Clone,
131    {
132        let input = self.advance_time(state, dt)?;
133        let output = self.model().call(input.clone())?;
134
135        Ok(State::new(input, output))
136    }
137
138    /// Runs the simulation for a fixed number of steps and collects the results.
139    ///
140    /// Starting from the given input, this method advances the simulation by
141    /// `steps` iterations using the proposed time step `dt`.
142    ///
143    /// # Parameters
144    ///
145    /// - `initial_input`: The model input at the start of the simulation.
146    /// - `steps`: The number of steps to run.
147    /// - `dt`: The proposed time step for each iteration.
148    ///
149    /// # Returns
150    ///
151    /// A `Vec` of length `steps + 1` containing each [`State`] computed during
152    /// the run, including the initial one.
153    ///
154    /// # Errors
155    ///
156    /// Returns a [`StepError`] if any step fails.
157    /// No further steps are taken after an error.
158    fn step_many(
159        &mut self,
160        initial_input: M::Input,
161        steps: usize,
162        dt: Duration,
163    ) -> Result<Vec<State<M>>, Self::StepError>
164    where
165        M::Input: Clone,
166        M::Output: Clone,
167    {
168        let mut results = Vec::with_capacity(steps + 1);
169
170        let output = self.model().call(initial_input.clone())?;
171        let mut current_state = State::new(initial_input, output);
172        results.push(current_state.clone());
173
174        for _ in 0..steps {
175            current_state = self.step_from_state(&current_state, dt)?;
176            results.push(current_state.clone());
177        }
178
179        Ok(results)
180    }
181
182    /// Consumes the simulation and creates an iterator that advances it repeatedly.
183    ///
184    /// The iterator calls the simulation's stepping logic with a constant `dt`,
185    /// yielding each resulting [`State`] in sequence.
186    /// If a step fails, the error is returned and iteration stops.
187    ///
188    /// This method supports lazy or streaming evaluation and integrates cleanly
189    /// with iterator adapters such as `.take(n)`, `.map(...)`, or `.find(...)`.
190    /// It is memory-efficient and performs no intermediate allocations.
191    ///
192    /// # Parameters
193    ///
194    /// - `initial_input`: The model input at the start of the simulation.
195    /// - `dt`: The proposed time step for each iteration.
196    ///
197    /// # Returns
198    ///
199    /// An iterator over `Result<State<M>, StepError>` steps.
200    fn into_step_iter(
201        self,
202        initial_input: M::Input,
203        dt: Duration,
204    ) -> impl Iterator<Item = Result<State<M>, Self::StepError>>
205    where
206        M::Input: Clone,
207        M::Output: Clone,
208    {
209        StepIter {
210            dt,
211            known: Some(Known::Input(initial_input)),
212            sim: self,
213        }
214    }
215}
216
217/// Represents a snapshot of the simulation at a specific point in time.
218///
219/// A [`State`] pairs:
220/// - `input`: The independent variables, typically user-controlled or time-evolving.
221/// - `output`: The dependent variables, computed by the model.
222///
223/// Together, these describe the full state of the system at a given instant.
224#[derive(Debug, Default, PartialEq, PartialOrd)]
225pub struct State<M: Model> {
226    pub input: M::Input,
227    pub output: M::Output,
228}
229
230impl<M: Model> State<M> {
231    /// Creates a [`State`] from the provided input and output.
232    pub fn new(input: M::Input, output: M::Output) -> Self {
233        Self { input, output }
234    }
235}
236
237impl<M: Model> Clone for State<M>
238where
239    M::Input: Clone,
240    M::Output: Clone,
241{
242    fn clone(&self) -> Self {
243        Self {
244            input: self.input.clone(),
245            output: self.output.clone(),
246        }
247    }
248}
249
250impl<M: Model> Copy for State<M>
251where
252    M::Input: Copy,
253    M::Output: Copy,
254{
255}
256
257/// An iterator that repeatedly steps the simulation using a proposed time step.
258///
259/// Starting from an initial input, this iterator repeatedly steps the
260/// simulation using `dt`, yielding each resulting [`State`] as a `Result`.
261///
262/// If any step fails, the error is yielded and iteration stops.
263struct StepIter<M: Model, S: Simulation<M>> {
264    dt: Duration,
265    known: Option<Known<M>>,
266    sim: S,
267}
268
269/// Internal state held by the [`StepIter`] iterator.
270enum Known<M: Model> {
271    /// The simulation has only been initialized with an input.
272    Input(M::Input),
273    /// The full simulation state is available.
274    State(State<M>),
275}
276
277impl<M, S> Iterator for StepIter<M, S>
278where
279    M: Model,
280    S: Simulation<M>,
281    M::Input: Clone,
282    M::Output: Clone,
283{
284    type Item = Result<State<M>, S::StepError>;
285
286    /// Advances the simulation by one step.
287    ///
288    /// - If starting from an input, calls the model to produce the first state.
289    /// - If continuing from a full state, steps the simulation forward.
290    /// - On success, yields a new [`State`].
291    /// - On error, yields a [`StepError`] and ends the iteration.
292    fn next(&mut self) -> Option<Self::Item> {
293        let known = self.known.take()?;
294
295        match known {
296            // A full state exists - step forward from it.
297            Known::State(state) => match self.sim.step_from_state(&state, self.dt) {
298                Ok(state) => {
299                    self.known = Some(Known::State(State::new(
300                        state.input.clone(),
301                        state.output.clone(),
302                    )));
303                    Some(Ok(state))
304                }
305                Err(error) => {
306                    self.known = None;
307                    Some(Err(error))
308                }
309            },
310
311            // Only the input is known - call the model and yield the first state.
312            Known::Input(input) => match self.sim.model().call(input.clone()) {
313                Ok(output) => {
314                    self.known = Some(Known::State(State::new(input.clone(), output.clone())));
315                    let state = State::new(input, output);
316                    Some(Ok(state))
317                }
318                Err(error) => {
319                    self.known = None;
320                    Some(Err(error.into()))
321                }
322            },
323        }
324    }
325}
326
327/// Marks that iteration always ends after the first `None`.
328impl<M, S> FusedIterator for StepIter<M, S>
329where
330    M: Model,
331    S: Simulation<M>,
332    M::Input: Clone,
333    M::Output: Clone,
334{
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    use std::convert::Infallible;

    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use thiserror::Error;

    /// A simple spring-damper model used for simulation tests.
    #[derive(Debug)]
    struct Spring {
        spring_constant: f64,
        damping_coef: f64,
    }

    #[derive(Debug, Clone, Default, PartialEq)]
    struct Input {
        time_in_minutes: f64,
        position: f64,
        velocity: f64,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct Output {
        acceleration: f64,
    }

    impl Model for Spring {
        type Input = Input;
        type Output = Output;
        type Error = Infallible;

        fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            let Input {
                position, velocity, ..
            } = input;

            let acceleration = -self.spring_constant * position - self.damping_coef * velocity;

            Ok(Output { acceleration })
        }
    }

    #[derive(Debug)]
    struct SpringSimulation {
        model: Spring,
    }

    impl Simulation<Spring> for SpringSimulation {
        type StepError = Infallible;

        fn model(&self) -> &Spring {
            &self.model
        }

        fn advance_time(
            &mut self,
            state: &State<Spring>,
            dt: Duration,
        ) -> Result<Input, Self::StepError> {
            let seconds = dt.as_secs_f64();
            let time_in_minutes = state.input.time_in_minutes + seconds / 60.0;

            // Explicit Euler update using the current state's velocity and acceleration.
            let position = state.input.position + state.input.velocity * seconds;
            let velocity = state.input.velocity + state.output.acceleration * seconds;

            Ok(Input {
                time_in_minutes,
                position,
                velocity,
            })
        }
    }

    #[test]
    fn zero_force_spring_has_constant_velocity() {
        let mut sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.0,
                damping_coef: 0.0,
            },
        };

        let initial = Input {
            time_in_minutes: 0.0,
            position: 10.0,
            velocity: 2.0,
        };

        let steps = 3;
        let dt = Duration::from_secs(30);

        let states = sim.step_many(initial, steps, dt).unwrap();

        let input_states: Vec<_> = states.iter().map(|s| s.input.clone()).collect();

        assert_eq!(
            input_states,
            vec![
                Input {
                    time_in_minutes: 0.0,
                    position: 10.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 0.5,
                    position: 70.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.0,
                    position: 130.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.5,
                    position: 190.0,
                    velocity: 2.0
                },
            ]
        );

        assert!(
            states.iter().all(|s| s.output.acceleration == 0.0),
            "All accelerations should be zero"
        );
    }

    #[test]
    fn damped_spring_sim_converges_to_zero() {
        let sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.5,
                damping_coef: 5.0,
            },
        };

        let initial = Input {
            position: 10.0,
            ..Input::default()
        };

        let dt = Duration::from_millis(100);

        let tolerance = 1e-4;
        let max_steps = 5000;

        let is_at_rest = |state: &State<Spring>| {
            state.input.position.abs() < tolerance
                && state.input.velocity.abs() < tolerance
                && state.output.acceleration.abs() < tolerance
        };

        // Use the step iterator to find the first state close enough to zero.
        let final_state = sim
            .into_step_iter(initial, dt)
            .take(max_steps)
            .find_map(|res| match res {
                Ok(state) if is_at_rest(&state) => Some(state),
                Ok(_) => None,
                Err(error) => panic!("Simulation error: {error:?}"),
            })
            // `Option::expect` takes a plain `&str` and would print `{max_steps}`
            // literally, so format the message with `panic!` instead.
            .unwrap_or_else(|| {
                panic!("Simulation did not reach a resting state within {max_steps} steps")
            });

        let State {
            input: final_input,
            output: final_output,
        } = final_state;

        assert_abs_diff_eq!(final_input.position, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_input.velocity, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_output.acceleration, 0.0, epsilon = tolerance);

        assert_relative_eq!(final_input.time_in_minutes, 1.875, epsilon = tolerance);
    }

    /// A model that fails if the input exceeds a specified maximum.
    #[derive(Debug)]
    struct CheckInput {
        max_value: usize,
    }

    impl Model for CheckInput {
        type Input = usize;
        type Output = ();
        type Error = CheckInputError;

        fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            if input <= self.max_value {
                Ok(())
            } else {
                Err(CheckInputError(input, self.max_value))
            }
        }
    }

    #[derive(Debug, Error)]
    #[error("{0} is bigger than max value of {1}")]
    struct CheckInputError(usize, usize);

    /// A test simulation using [`CheckInput`].
    ///
    /// Each step increments the input by 1.
    /// Yields an error when the input exceeds the maximum threshold `N`.
    #[derive(Debug)]
    struct CheckInputSim<const N: usize>;

    impl<const N: usize> Simulation<CheckInput> for CheckInputSim<N> {
        type StepError = CheckInputError;

        fn model(&self) -> &CheckInput {
            &CheckInput { max_value: N }
        }

        fn advance_time(
            &mut self,
            state: &State<CheckInput>,
            _dt: Duration,
        ) -> Result<usize, Self::StepError> {
            Ok(state.input + 1)
        }
    }

    #[test]
    fn step_iter_yields_error_correctly() {
        let mut iter = CheckInputSim::<3>.into_step_iter(0, Duration::from_secs(1));

        let state = iter
            .next()
            .expect("Initial call yields a result")
            .expect("Initial call is a success");
        assert_eq!(state.input, 0);

        let state = iter
            .next()
            .expect("First step yields a result")
            .expect("First step is a success");
        assert_eq!(state.input, 1);

        let state = iter
            .next()
            .expect("Second step yields a result")
            .expect("Second step is a success");
        assert_eq!(state.input, 2);

        let state = iter
            .next()
            .expect("Third step yields a result")
            .expect("Third step is a success");
        assert_eq!(state.input, 3);

        let error = iter
            .next()
            .expect("Fourth step yields a result")
            .expect_err("Fourth step is an error");
        assert_eq!(format!("{error}"), "4 is bigger than max value of 3");

        assert!(iter.next().is_none());
    }
}