twine_core/simulation.rs
1use std::{iter::FusedIterator, time::Duration};
2
/// Trait for defining a system model in Twine.
///
/// A model describes a system with a clear boundary: it maps a typed
/// [`Input`](Model::Input) to a typed [`Output`](Model::Output).
/// Models must be deterministic — evaluating the same input twice must give
/// the same result — and that property is what lets a [`Simulation`] evolve
/// a model through time.
pub trait Model {
    /// The independent variables handed to [`call`](Model::call).
    type Input;
    /// The dependent variables produced by [`call`](Model::call).
    type Output;
    /// The failure type reported by [`call`](Model::call).
    type Error: std::error::Error + Send + Sync + 'static;

    /// Evaluates the model for `input`.
    ///
    /// # Errors
    ///
    /// What counts as a failure is up to each model: every implementation picks
    /// its own [`Error`](Model::Error) type and decides which inputs to reject.
    fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
}
21
22/// Trait for defining a transient simulation in Twine.
23///
24/// A `Simulation` advances a [`Model`] forward in time by computing its next
25/// input with [`advance_time`] and then calling the model to produce a new
26/// [`State`] representing the system at the corresponding future moment.
27///
28/// # Stepping Methods
29///
30/// After implementing [`advance_time`], the following methods are available for
31/// advancing the simulation:
32///
33/// - [`Simulation::step`]: Takes a single step from an initial input.
34/// - [`Simulation::step_from_state`]: Takes a single step from a known state.
35/// - [`Simulation::step_many`]: Takes multiple steps and collects all resulting states.
36/// - [`Simulation::into_step_iter`]: Consumes the simulation and returns an iterator over its steps.
37pub trait Simulation<M: Model>: Sized {
38 /// The error type returned if a simulation step fails.
39 ///
40 /// This type must implement [`From<M::Error>`] so errors produced by the model
41 /// (via [`Model::call`]) can be automatically converted using the `?` operator.
42 /// This requirement allows simulations to propagate model errors cleanly
43 /// when calling the model during a step or within [`advance_time`].
44 ///
45 /// Implementations may:
46 /// - Reuse the model's error type directly (`type StepError = M::Error`).
47 /// - Wrap it in a custom enum with additional error variants.
48 /// - Use boxed dynamic errors for maximum flexibility.
49 type StepError: std::error::Error + Send + Sync + 'static + From<M::Error>;
50
51 /// Provides a reference to the model being simulated.
52 fn model(&self) -> &M;
53
54 /// Computes the next input for the model, advancing the simulation in time.
55 ///
56 /// Given the current [`State`] and a proposed time step `dt`, this method
57 /// generates the next [`Model::Input`] to drive the simulation forward.
58 ///
59 /// This method is the primary customization point for incrementing time,
60 /// integrating state variables, enforcing constraints, applying control
61 /// logic, or incorporating external events.
62 /// It takes `&mut self` to support stateful integration algorithms such as
63 /// adaptive time stepping, multistep methods, or PID controllers that need
64 /// to record history.
65 ///
66 /// Implementations may interpret or adapt the proposed time step `dt` as
67 /// needed (e.g., for adaptive time stepping), and are free to update any
68 /// fields of the input required to continue the simulation.
69 ///
70 /// # Parameters
71 ///
72 /// - `state`: The current simulation state.
73 /// - `dt`: The proposed time step.
74 ///
75 /// # Returns
76 ///
77 /// The next input, computed from the current [`State`] and proposed `dt`.
78 ///
79 /// # Errors
80 ///
81 /// Returns a [`StepError`] if computing the next input fails.
82 fn advance_time(&mut self, state: &State<M>, dt: Duration)
83 -> Result<M::Input, Self::StepError>;
84
85 /// Advances the simulation by one step, starting from an initial input.
86 ///
87 /// This method first calls the model with the given input to compute
88 /// the initial output, forming a complete [`State`].
89 /// It then delegates to [`step_from_state`] to compute the next state.
90 /// As a result, the model is called twice: once to initialize the
91 /// state, and once after advancing.
92 ///
93 /// # Parameters
94 ///
95 /// - `input`: The model input at the start of the step.
96 /// - `dt`: The proposed time step.
97 ///
98 /// # Errors
99 ///
100 /// Returns a [`StepError`] if computing the next input or calling the model fails.
101 fn step(&mut self, input: M::Input, dt: Duration) -> Result<State<M>, Self::StepError>
102 where
103 M::Input: Clone,
104 {
105 let output = self.model().call(input.clone())?;
106 let state = State::new(input, output);
107
108 self.step_from_state(&state, dt)
109 }
110
111 /// Advances the simulation by one step from a known [`State`].
112 ///
113 /// This method computes the next input using [`advance_time`],
114 /// then calls the model to produce the resulting [`State`].
115 ///
116 /// # Parameters
117 ///
118 /// - `state`: The current simulation state.
119 /// - `dt`: The proposed time step.
120 ///
121 /// # Errors
122 ///
123 /// Returns a [`StepError`] if computing the next input or calling the model fails.
124 fn step_from_state(
125 &mut self,
126 state: &State<M>,
127 dt: Duration,
128 ) -> Result<State<M>, Self::StepError>
129 where
130 M::Input: Clone,
131 {
132 let input = self.advance_time(state, dt)?;
133 let output = self.model().call(input.clone())?;
134
135 Ok(State::new(input, output))
136 }
137
138 /// Runs the simulation for a fixed number of steps and collects the results.
139 ///
140 /// Starting from the given input, this method advances the simulation by
141 /// `steps` iterations using the proposed time step `dt`.
142 ///
143 /// # Parameters
144 ///
145 /// - `initial_input`: The model input at the start of the simulation.
146 /// - `steps`: The number of steps to run.
147 /// - `dt`: The proposed time step for each iteration.
148 ///
149 /// # Returns
150 ///
151 /// A `Vec` of length `steps + 1` containing each [`State`] computed during
152 /// the run, including the initial one.
153 ///
154 /// # Errors
155 ///
156 /// Returns a [`StepError`] if any step fails.
157 /// No further steps are taken after an error.
158 fn step_many(
159 &mut self,
160 initial_input: M::Input,
161 steps: usize,
162 dt: Duration,
163 ) -> Result<Vec<State<M>>, Self::StepError>
164 where
165 M::Input: Clone,
166 M::Output: Clone,
167 {
168 let mut results = Vec::with_capacity(steps + 1);
169
170 let output = self.model().call(initial_input.clone())?;
171 let mut current_state = State::new(initial_input, output);
172 results.push(current_state.clone());
173
174 for _ in 0..steps {
175 current_state = self.step_from_state(¤t_state, dt)?;
176 results.push(current_state.clone());
177 }
178
179 Ok(results)
180 }
181
182 /// Consumes the simulation and creates an iterator that advances it repeatedly.
183 ///
184 /// The iterator calls the simulation's stepping logic with a constant `dt`,
185 /// yielding each resulting [`State`] in sequence.
186 /// If a step fails, the error is returned and iteration stops.
187 ///
188 /// This method supports lazy or streaming evaluation and integrates cleanly
189 /// with iterator adapters such as `.take(n)`, `.map(...)`, or `.find(...)`.
190 /// It is memory-efficient and performs no intermediate allocations.
191 ///
192 /// # Parameters
193 ///
194 /// - `initial_input`: The model input at the start of the simulation.
195 /// - `dt`: The proposed time step for each iteration.
196 ///
197 /// # Returns
198 ///
199 /// An iterator over `Result<State<M>, StepError>` steps.
200 fn into_step_iter(
201 self,
202 initial_input: M::Input,
203 dt: Duration,
204 ) -> impl Iterator<Item = Result<State<M>, Self::StepError>>
205 where
206 M::Input: Clone,
207 M::Output: Clone,
208 {
209 StepIter {
210 dt,
211 known: Some(Known::Input(initial_input)),
212 sim: self,
213 }
214 }
215}
216
217/// Represents a snapshot of the simulation at a specific point in time.
218///
219/// A [`State`] pairs:
220/// - `input`: The independent variables, typically user-controlled or time-evolving.
221/// - `output`: The dependent variables, computed by the model.
222///
223/// Together, these describe the full state of the system at a given instant.
224#[derive(Debug, Default, PartialEq, PartialOrd)]
225pub struct State<M: Model> {
226 pub input: M::Input,
227 pub output: M::Output,
228}
229
230impl<M: Model> State<M> {
231 /// Creates a [`State`] from the provided input and output.
232 pub fn new(input: M::Input, output: M::Output) -> Self {
233 Self { input, output }
234 }
235}
236
237impl<M: Model> Clone for State<M>
238where
239 M::Input: Clone,
240 M::Output: Clone,
241{
242 fn clone(&self) -> Self {
243 Self {
244 input: self.input.clone(),
245 output: self.output.clone(),
246 }
247 }
248}
249
250impl<M: Model> Copy for State<M>
251where
252 M::Input: Copy,
253 M::Output: Copy,
254{
255}
256
257/// An iterator that repeatedly steps the simulation using a proposed time step.
258///
259/// Starting from an initial input, this iterator repeatedly steps the
260/// simulation using `dt`, yielding each resulting [`State`] as a `Result`.
261///
262/// If any step fails, the error is yielded and iteration stops.
263struct StepIter<M: Model, S: Simulation<M>> {
264 dt: Duration,
265 known: Option<Known<M>>,
266 sim: S,
267}
268
269/// Internal state held by the [`StepIter`] iterator.
270enum Known<M: Model> {
271 /// The simulation has only been initialized with an input.
272 Input(M::Input),
273 /// The full simulation state is available.
274 State(State<M>),
275}
276
277impl<M, S> Iterator for StepIter<M, S>
278where
279 M: Model,
280 S: Simulation<M>,
281 M::Input: Clone,
282 M::Output: Clone,
283{
284 type Item = Result<State<M>, S::StepError>;
285
286 /// Advances the simulation by one step.
287 ///
288 /// - If starting from an input, calls the model to produce the first state.
289 /// - If continuing from a full state, steps the simulation forward.
290 /// - On success, yields a new [`State`].
291 /// - On error, yields a [`StepError`] and ends the iteration.
292 fn next(&mut self) -> Option<Self::Item> {
293 let known = self.known.take()?;
294
295 match known {
296 // A full state exists - step forward from it.
297 Known::State(state) => match self.sim.step_from_state(&state, self.dt) {
298 Ok(state) => {
299 self.known = Some(Known::State(State::new(
300 state.input.clone(),
301 state.output.clone(),
302 )));
303 Some(Ok(state))
304 }
305 Err(error) => {
306 self.known = None;
307 Some(Err(error))
308 }
309 },
310
311 // Only the input is known - call the model and yield the first state.
312 Known::Input(input) => match self.sim.model().call(input.clone()) {
313 Ok(output) => {
314 self.known = Some(Known::State(State::new(input.clone(), output.clone())));
315 let state = State::new(input, output);
316 Some(Ok(state))
317 }
318 Err(error) => {
319 self.known = None;
320 Some(Err(error.into()))
321 }
322 },
323 }
324 }
325}
326
327/// Marks that iteration always ends after the first `None`.
328impl<M, S> FusedIterator for StepIter<M, S>
329where
330 M: Model,
331 S: Simulation<M>,
332 M::Input: Clone,
333 M::Output: Clone,
334{
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    use std::convert::Infallible;

    use approx::{assert_abs_diff_eq, assert_relative_eq};
    use thiserror::Error;

    /// A simple spring-damper model used for simulation tests.
    #[derive(Debug)]
    struct Spring {
        spring_constant: f64,
        damping_coef: f64,
    }

    #[derive(Debug, Clone, Default, PartialEq)]
    struct Input {
        time_in_minutes: f64,
        position: f64,
        velocity: f64,
    }

    #[derive(Debug, Clone, PartialEq)]
    struct Output {
        acceleration: f64,
    }

    impl Model for Spring {
        type Input = Input;
        type Output = Output;
        type Error = Infallible;

        fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            let Input {
                position, velocity, ..
            } = input;

            let acceleration = -self.spring_constant * position - self.damping_coef * velocity;

            Ok(Output { acceleration })
        }
    }

    #[derive(Debug)]
    struct SpringSimulation {
        model: Spring,
    }

    impl Simulation<Spring> for SpringSimulation {
        type StepError = Infallible;

        fn model(&self) -> &Spring {
            &self.model
        }

        fn advance_time(
            &mut self,
            state: &State<Spring>,
            dt: Duration,
        ) -> Result<Input, Self::StepError> {
            let seconds = dt.as_secs_f64();
            let time_in_minutes = state.input.time_in_minutes + seconds / 60.0;

            let position = state.input.position + state.input.velocity * seconds;
            let velocity = state.input.velocity + state.output.acceleration * seconds;

            Ok(Input {
                time_in_minutes,
                position,
                velocity,
            })
        }
    }

    #[test]
    fn zero_force_spring_has_constant_velocity() {
        let mut sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.0,
                damping_coef: 0.0,
            },
        };

        let initial = Input {
            time_in_minutes: 0.0,
            position: 10.0,
            velocity: 2.0,
        };

        let steps = 3;
        let dt = Duration::from_secs(30);

        let states = sim.step_many(initial, steps, dt).unwrap();

        let input_states: Vec<_> = states.iter().map(|s| s.input.clone()).collect();

        assert_eq!(
            input_states,
            vec![
                Input {
                    time_in_minutes: 0.0,
                    position: 10.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 0.5,
                    position: 70.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.0,
                    position: 130.0,
                    velocity: 2.0
                },
                Input {
                    time_in_minutes: 1.5,
                    position: 190.0,
                    velocity: 2.0
                },
            ]
        );

        assert!(
            states.iter().all(|s| s.output.acceleration == 0.0),
            "All accelerations should be zero"
        );
    }

    #[test]
    fn damped_spring_sim_converges_to_zero() {
        let sim = SpringSimulation {
            model: Spring {
                spring_constant: 0.5,
                damping_coef: 5.0,
            },
        };

        let initial = Input {
            position: 10.0,
            ..Input::default()
        };

        let dt = Duration::from_millis(100);

        let tolerance = 1e-4;
        let max_steps = 5000;

        let is_at_rest = |state: &State<Spring>| {
            state.input.position.abs() < tolerance
                && state.input.velocity.abs() < tolerance
                && state.output.acceleration.abs() < tolerance
        };

        // Use the step iterator to find the first state close enough to zero.
        let final_state = sim
            .into_step_iter(initial, dt)
            .take(max_steps)
            .find_map(|res| match res {
                Ok(state) if is_at_rest(&state) => Some(state),
                Ok(_) => None,
                Err(error) => panic!("Simulation error: {error:?}"),
            })
            // `Option::expect` does not format its message, so use `panic!`
            // to actually interpolate `max_steps`.
            .unwrap_or_else(|| {
                panic!("Simulation did not reach a resting state within {max_steps} steps")
            });

        let State {
            input: final_input,
            output: final_output,
        } = final_state;

        assert_abs_diff_eq!(final_input.position, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_input.velocity, 0.0, epsilon = tolerance);
        assert_abs_diff_eq!(final_output.acceleration, 0.0, epsilon = tolerance);

        assert_relative_eq!(final_input.time_in_minutes, 1.875, epsilon = tolerance);
    }

    /// A model that fails if the input exceeds a specified maximum.
    #[derive(Debug)]
    struct CheckInput {
        max_value: usize,
    }

    impl Model for CheckInput {
        type Input = usize;
        type Output = ();
        type Error = CheckInputError;

        fn call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
            if input <= self.max_value {
                Ok(())
            } else {
                Err(CheckInputError(input, self.max_value))
            }
        }
    }

    #[derive(Debug, Error)]
    #[error("{0} is bigger than max value of {1}")]
    struct CheckInputError(usize, usize);

    /// A test simulation using [`CheckInput`].
    ///
    /// Each step increments the input by 1.
    /// Yields an error when the input exceeds the maximum threshold `N`.
    #[derive(Debug)]
    struct CheckInputSim<const N: usize>;

    impl<const N: usize> Simulation<CheckInput> for CheckInputSim<N> {
        type StepError = CheckInputError;

        fn model(&self) -> &CheckInput {
            &CheckInput { max_value: N }
        }

        fn advance_time(
            &mut self,
            state: &State<CheckInput>,
            _dt: Duration,
        ) -> Result<usize, Self::StepError> {
            Ok(state.input + 1)
        }
    }

    #[test]
    fn step_iter_yields_error_correctly() {
        let mut iter = CheckInputSim::<3>.into_step_iter(0, Duration::from_secs(1));

        let state = iter
            .next()
            .expect("Initial call yields a result")
            .expect("Initial call is a success");
        assert_eq!(state.input, 0);

        let state = iter
            .next()
            .expect("First step yields a result")
            .expect("First step is a success");
        assert_eq!(state.input, 1);

        let state = iter
            .next()
            .expect("Second step yields a result")
            .expect("Second step is a success");
        assert_eq!(state.input, 2);

        let state = iter
            .next()
            .expect("Third step yields a result")
            .expect("Third step is a success");
        assert_eq!(state.input, 3);

        let error = iter
            .next()
            .expect("Fourth step yields a result")
            .expect_err("Fourth step is an error");
        assert_eq!(format!("{error}"), "4 is bigger than max value of 3");

        assert!(iter.next().is_none());
    }
}