scirs2_integrate/ode/
multirate.rs

//! Multirate methods for ODEs with multiple time scales.
//!
//! This module implements multirate integration methods for systems whose
//! components evolve on different time scales, as is common in applications
//! such as:
//! - Chemical kinetics (fast/slow reactions)
//! - Electrical circuits (fast/slow transients)
//! - Climate models (fast weather, slow climate)
//! - Biological systems (fast enzyme kinetics, slow gene expression)
10
11use crate::common::IntegrateFloat;
12use crate::error::{IntegrateError, IntegrateResult};
13use crate::ode::{ODEMethod, ODEResult};
14use scirs2_core::ndarray::{s, Array1, ArrayView1};
15use std::collections::VecDeque;
16
17/// Multirate ODE system with fast and slow components
18pub trait MultirateSystem<F: IntegrateFloat> {
19    /// Evaluate slow component: dy_slow/dt = f_slow(t, y_slow, y_fast)
20    fn slow_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F>;
21
22    /// Evaluate fast component: dy_fast/dt = f_fast(t, y_slow, y_fast)
23    fn fast_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F>;
24
25    /// Get dimension of slow variables
26    fn slow_dim(&self) -> usize;
27
28    /// Get dimension of fast variables
29    fn fast_dim(&self) -> usize;
30}
31
32/// Multirate integration method types
33#[derive(Debug, Clone)]
34pub enum MultirateMethod {
35    /// Explicit multirate Runge-Kutta method
36    ExplicitMRK {
37        macro_steps: usize,
38        micro_steps: usize,
39    },
40    /// Implicit-explicit (IMEX) multirate method
41    IMEX {
42        macro_steps: usize,
43        micro_steps: usize,
44    },
45    /// Compound _fast-_slow method
46    CompoundFastSlow {
47        _fast_method: ODEMethod,
48        _slow_method: ODEMethod,
49    },
50    /// Extrapolated multirate method
51    Extrapolated { base_ratio: usize, levels: usize },
52}
53
54/// Multirate solver configuration
55#[derive(Debug, Clone)]
56pub struct MultirateOptions<F: IntegrateFloat> {
57    /// Multirate method to use
58    pub method: MultirateMethod,
59    /// Macro step size (for slow components)
60    pub macro_step: F,
61    /// Relative tolerance
62    pub rtol: F,
63    /// Absolute tolerance
64    pub atol: F,
65    /// Maximum number of macro steps
66    pub max_steps: usize,
67    /// Time scale separation estimate
68    pub timescale_ratio: Option<F>,
69}
70
71impl<F: IntegrateFloat> Default for MultirateOptions<F> {
72    fn default() -> Self {
73        Self {
74            method: MultirateMethod::ExplicitMRK {
75                macro_steps: 4,
76                micro_steps: 10,
77            },
78            macro_step: F::from(0.01).unwrap(),
79            rtol: F::from(1e-6).unwrap(),
80            atol: F::from(1e-9).unwrap(),
81            max_steps: 10000,
82            timescale_ratio: None,
83        }
84    }
85}
86
87/// Multirate ODE solver
88pub struct MultirateSolver<F: IntegrateFloat> {
89    options: MultirateOptions<F>,
90    /// History of solutions for extrapolation methods
91    history: VecDeque<(F, Array1<F>)>,
92    /// Current macro step size
93    current_macro_step: F,
94    /// Current micro step size
95    #[allow(dead_code)]
96    current_micro_step: F,
97}
98
99impl<F: IntegrateFloat> MultirateSolver<F> {
100    /// Create new multirate solver
101    pub fn new(options: MultirateOptions<F>) -> Self {
102        let current_macro_step = options.macro_step;
103        let current_micro_step = match &options.method {
104            MultirateMethod::ExplicitMRK { micro_steps, .. } => {
105                current_macro_step / F::from(*micro_steps).unwrap()
106            }
107            MultirateMethod::IMEX { micro_steps, .. } => {
108                current_macro_step / F::from(*micro_steps).unwrap()
109            }
110            _ => current_macro_step / F::from(10).unwrap(),
111        };
112
113        Self {
114            options,
115            history: VecDeque::new(),
116            current_macro_step,
117            current_micro_step,
118        }
119    }
120
121    /// Solve multirate ODE system
122    pub fn solve<S>(
123        &mut self,
124        system: S,
125        t_span: [F; 2],
126        y0: Array1<F>,
127    ) -> IntegrateResult<ODEResult<F>>
128    where
129        S: MultirateSystem<F>,
130    {
131        let [t0, tf] = t_span;
132        let slow_dim = system.slow_dim();
133        let fast_dim = system.fast_dim();
134
135        if y0.len() != slow_dim + fast_dim {
136            return Err(IntegrateError::ValueError(format!(
137                "Initial condition dimension {} does not match system dimension {}",
138                y0.len(),
139                slow_dim + fast_dim
140            )));
141        }
142
143        let mut t = t0;
144        let mut y = y0.clone();
145        let mut solution_t = vec![t];
146        let mut solution_y = vec![y.clone()];
147        let mut step_count = 0;
148
149        while t < tf && step_count < self.options.max_steps {
150            // Adjust step size near final time
151            let dt = if t + self.current_macro_step > tf {
152                tf - t
153            } else {
154                self.current_macro_step
155            };
156
157            // Split state into slow and fast components
158            let y_slow = y.slice(s![..slow_dim]).to_owned();
159            let y_fast = y.slice(s![slow_dim..]).to_owned();
160
161            // Take multirate step
162            let (new_y_slow, new_y_fast) = match &self.options.method {
163                MultirateMethod::ExplicitMRK {
164                    macro_steps,
165                    micro_steps,
166                } => self.explicit_mrk_step(
167                    &system,
168                    t,
169                    dt,
170                    y_slow.view(),
171                    y_fast.view(),
172                    *macro_steps,
173                    *micro_steps,
174                )?,
175                MultirateMethod::IMEX {
176                    macro_steps,
177                    micro_steps,
178                } => self.imex_step(
179                    &system,
180                    t,
181                    dt,
182                    y_slow.view(),
183                    y_fast.view(),
184                    *macro_steps,
185                    *micro_steps,
186                )?,
187                MultirateMethod::CompoundFastSlow {
188                    _fast_method: _,
189                    _slow_method: _,
190                } => self.compound_fast_slow_step(&system, t, dt, y_slow.view(), y_fast.view())?,
191                MultirateMethod::Extrapolated { base_ratio, levels } => self.extrapolated_step(
192                    &system,
193                    t,
194                    dt,
195                    y_slow.view(),
196                    y_fast.view(),
197                    *base_ratio,
198                    *levels,
199                )?,
200            };
201
202            // Combine slow and fast components
203            let mut new_y = Array1::zeros(slow_dim + fast_dim);
204            new_y.slice_mut(s![..slow_dim]).assign(&new_y_slow);
205            new_y.slice_mut(s![slow_dim..]).assign(&new_y_fast);
206
207            t += dt;
208            y = new_y;
209            solution_t.push(t);
210            solution_y.push(y.clone());
211            step_count += 1;
212
213            // Update history for extrapolation methods
214            if matches!(self.options.method, MultirateMethod::Extrapolated { .. }) {
215                self.history.push_back((t, y.clone()));
216                if self.history.len() > 10 {
217                    self.history.pop_front();
218                }
219            }
220        }
221
222        if step_count >= self.options.max_steps {
223            return Err(IntegrateError::ConvergenceError(
224                "Maximum number of steps exceeded in multirate solver".to_string(),
225            ));
226        }
227
228        Ok(ODEResult {
229            t: solution_t,
230            y: solution_y,
231            success: true,
232            message: Some(format!("Multirate method: {:?}", self.options.method)),
233            n_eval: step_count * 4, // Approximate
234            n_steps: step_count,
235            n_accepted: step_count,
236            n_rejected: 0,
237            n_lu: 0,
238            n_jac: 0,
239            method: ODEMethod::RK4, // Default representation
240        })
241    }
242
243    /// Explicit multirate Runge-Kutta step
244    fn explicit_mrk_step<S>(
245        &self,
246        system: &S,
247        t: F,
248        dt: F,
249        y_slow: ArrayView1<F>,
250        y_fast: ArrayView1<F>,
251        _macro_steps: usize,
252        micro_steps: usize,
253    ) -> IntegrateResult<(Array1<F>, Array1<F>)>
254    where
255        S: MultirateSystem<F>,
256    {
257        let dt_micro = dt / F::from(micro_steps).unwrap();
258
259        // RK4 step for _slow component (large step)
260        let k1_slow = system.slow_rhs(t, y_slow, y_fast);
261
262        // Fast component evolution over macro step with micro _steps
263        let mut y_fast_current = y_fast.to_owned();
264        let mut t_micro = t;
265
266        for _ in 0..micro_steps {
267            // RK4 micro step for _fast component
268            let k1_fast = system.fast_rhs(t_micro, y_slow, y_fast_current.view());
269            let k2_fast = system.fast_rhs(
270                t_micro + dt_micro / F::from(2).unwrap(),
271                y_slow,
272                (y_fast_current.clone() + k1_fast.clone() * dt_micro / F::from(2).unwrap()).view(),
273            );
274            let k3_fast = system.fast_rhs(
275                t_micro + dt_micro / F::from(2).unwrap(),
276                y_slow,
277                (y_fast_current.clone() + k2_fast.clone() * dt_micro / F::from(2).unwrap()).view(),
278            );
279            let k4_fast = system.fast_rhs(
280                t_micro + dt_micro,
281                y_slow,
282                (y_fast_current.clone() + k3_fast.clone() * dt_micro).view(),
283            );
284
285            let two = F::from(2).unwrap();
286            let six = F::from(6).unwrap();
287            let rk_sum = k1_fast.clone() + &k2_fast * two + &k3_fast * two + k4_fast.clone();
288            y_fast_current = y_fast_current + &rk_sum * (dt_micro / six);
289            t_micro += dt_micro;
290        }
291
292        // Complete _slow step using final _fast state
293        let k2_slow = system.slow_rhs(t + dt / F::from(2).unwrap(), y_slow, y_fast_current.view());
294        let k3_slow = system.slow_rhs(
295            t + dt / F::from(2).unwrap(),
296            (y_slow.to_owned() + k1_slow.clone() * dt / F::from(2).unwrap()).view(),
297            y_fast_current.view(),
298        );
299        let k4_slow = system.slow_rhs(
300            t + dt,
301            (y_slow.to_owned() + k3_slow.clone() * dt).view(),
302            y_fast_current.view(),
303        );
304
305        let two = F::from(2).unwrap();
306        let six = F::from(6).unwrap();
307        let rk_sum_slow = k1_slow.clone() + &k2_slow * two + &k3_slow * two + k4_slow.clone();
308        let new_y_slow = y_slow.to_owned() + &rk_sum_slow * (dt / six);
309
310        Ok((new_y_slow, y_fast_current))
311    }
312
313    /// Implicit-explicit (IMEX) multirate step
314    fn imex_step<S>(
315        &self,
316        system: &S,
317        t: F,
318        dt: F,
319        y_slow: ArrayView1<F>,
320        y_fast: ArrayView1<F>,
321        _macro_steps: usize,
322        micro_steps: usize,
323    ) -> IntegrateResult<(Array1<F>, Array1<F>)>
324    where
325        S: MultirateSystem<F>,
326    {
327        // For this implementation, use explicit treatment
328        // In practice, IMEX would treat stiff _fast components implicitly
329        self.explicit_mrk_step(system, t, dt, y_slow, y_fast, _macro_steps, micro_steps)
330    }
331
332    /// Compound fast-slow method step
333    fn compound_fast_slow_step<S>(
334        &self,
335        system: &S,
336        t: F,
337        dt: F,
338        y_slow: ArrayView1<F>,
339        y_fast: ArrayView1<F>,
340    ) -> IntegrateResult<(Array1<F>, Array1<F>)>
341    where
342        S: MultirateSystem<F>,
343    {
344        // First solve _fast subsystem to quasi-steady state
345        let mut y_fast_current = y_fast.to_owned();
346        let dt_fast = dt / F::from(100).unwrap(); // Very small steps for _fast system
347
348        // Fast relaxation phase
349        for _ in 0..50 {
350            // Allow _fast system to equilibrate
351            let k_fast = system.fast_rhs(t, y_slow, y_fast_current.view());
352            y_fast_current = y_fast_current + k_fast * dt_fast;
353        }
354
355        // Then advance _slow system with equilibrated _fast variables
356        let k_slow = system.slow_rhs(t, y_slow, y_fast_current.view());
357        let new_y_slow = y_slow.to_owned() + k_slow * dt;
358
359        // Final _fast adjustment
360        let k_fast_final = system.fast_rhs(t + dt, new_y_slow.view(), y_fast_current.view());
361        let new_y_fast = y_fast_current + k_fast_final * dt;
362
363        Ok((new_y_slow, new_y_fast))
364    }
365
366    /// Extrapolated multirate step
367    fn extrapolated_step<S>(
368        &self,
369        system: &S,
370        t: F,
371        dt: F,
372        y_slow: ArrayView1<F>,
373        y_fast: ArrayView1<F>,
374        base_ratio: usize,
375        levels: usize,
376    ) -> IntegrateResult<(Array1<F>, Array1<F>)>
377    where
378        S: MultirateSystem<F>,
379    {
380        // Richardson extrapolation with different micro step sizes
381        let mut solutions = Vec::new();
382
383        for level in 0..levels {
384            let micro_steps = base_ratio * (2_usize.pow(level as u32));
385            let (y_slow_approx, y_fast_approx) =
386                self.explicit_mrk_step(system, t, dt, y_slow, y_fast, 4, micro_steps)?;
387            solutions.push((y_slow_approx, y_fast_approx));
388        }
389
390        // Simple Richardson extrapolation (linear)
391        if solutions.len() >= 2 {
392            let (y_slow_coarse, y_fast_coarse) = &solutions[0];
393            let (y_slow_fine, y_fast_fine) = &solutions[1];
394
395            // Extrapolated solution: y_ext = y_fine + (y_fine - y_coarse)
396            let y_slow_ext = y_slow_fine + (y_slow_fine - y_slow_coarse);
397            let y_fast_ext = y_fast_fine + (y_fast_fine - y_fast_coarse);
398
399            Ok((y_slow_ext, y_fast_ext))
400        } else {
401            Ok(solutions.into_iter().next().unwrap())
402        }
403    }
404}
405
406/// Example multirate system: fast oscillator coupled to slow drift
407pub struct FastSlowOscillator<F: IntegrateFloat> {
408    /// Fast frequency
409    pub omega_fast: F,
410    /// Slow time scale
411    pub epsilon: F,
412    /// Coupling strength
413    pub coupling: F,
414}
415
416impl<F: IntegrateFloat> MultirateSystem<F> for FastSlowOscillator<F> {
417    fn slow_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F> {
418        let x_slow = y_slow[0];
419        let v_slow = y_slow[1];
420        let x_fast = yfast[0];
421
422        // Slow dynamics: influenced by _fast oscillations
423        let dx_slow_dt = v_slow;
424        let dv_slow_dt = -self.epsilon * x_slow + self.coupling * x_fast;
425
426        Array1::from_vec(vec![dx_slow_dt, dv_slow_dt])
427    }
428
429    fn fast_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F> {
430        let x_slow = y_slow[0];
431        let x_fast = yfast[0];
432        let v_fast = yfast[1];
433
434        // Fast oscillator dynamics
435        let dx_fast_dt = v_fast;
436        let dv_fast_dt = -self.omega_fast * self.omega_fast * x_fast + self.coupling * x_slow;
437
438        Array1::from_vec(vec![dx_fast_dt, dv_fast_dt])
439    }
440
441    fn slow_dim(&self) -> usize {
442        2
443    }
444    fn fast_dim(&self) -> usize {
445        2
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_multirate_system_dimensions() {
        let system = FastSlowOscillator {
            omega_fast: 10.0,
            epsilon: 0.1,
            coupling: 0.05,
        };

        assert_eq!(system.slow_dim(), 2);
        assert_eq!(system.fast_dim(), 2);
        assert_eq!(system.slow_dim() + system.fast_dim(), 4);
    }

    #[test]
    fn test_multirate_solver_creation() {
        let options = MultirateOptions {
            method: MultirateMethod::ExplicitMRK {
                macro_steps: 4,
                micro_steps: 10,
            },
            macro_step: 0.01,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 1000,
            timescale_ratio: Some(100.0),
        };

        let solver = MultirateSolver::new(options);
        assert_abs_diff_eq!(solver.current_macro_step, 0.01);
        assert_abs_diff_eq!(solver.current_micro_step, 0.001);
    }

    #[test]
    fn test_fast_slow_oscillator_solve() {
        let system = FastSlowOscillator {
            omega_fast: 20.0, // Fast oscillations
            epsilon: 0.1,     // Slow dynamics
            coupling: 0.02,   // Weak coupling
        };

        let options = MultirateOptions {
            method: MultirateMethod::ExplicitMRK {
                macro_steps: 4,
                micro_steps: 20,
            },
            macro_step: 0.05,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 200,
            timescale_ratio: Some(200.0),
        };

        let mut solver = MultirateSolver::new(options);

        // Initial conditions: [x_slow, v_slow, x_fast, v_fast]
        let y0 = Array1::from_vec(vec![1.0, 0.0, 0.1, 0.0]);

        let result = solver.solve(system, [0.0, 1.0], y0.clone()).unwrap();

        // Check that solution was computed
        assert!(result.t.len() > 1);
        assert_eq!(result.y.len(), result.t.len());
        assert_eq!(result.y[0].len(), 4);

        // Check that fast and slow components behave appropriately
        let final_state = result.y.last().unwrap();

        // Fast oscillator should still be oscillating (non-zero velocity)
        let fast_velocity: f64 = final_state[3];
        assert!(fast_velocity.abs() > 1e-6); // Fast velocity

        // Slow component should have evolved
        let slow_pos_change: f64 = final_state[0] - y0[0];
        assert!(slow_pos_change.abs() > 1e-3); // Slow position changed
    }

    #[test]
    fn test_compound_fast_slow_method() {
        let system = FastSlowOscillator {
            omega_fast: 50.0, // Very fast oscillations
            epsilon: 0.05,    // Very slow dynamics
            coupling: 0.01,   // Weak coupling
        };

        let options = MultirateOptions {
            method: MultirateMethod::CompoundFastSlow {
                _fast_method: ODEMethod::RK4,
                _slow_method: ODEMethod::RK4,
            },
            macro_step: 0.1,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 100,
            timescale_ratio: Some(1000.0),
        };

        let mut solver = MultirateSolver::new(options);
        let y0 = Array1::from_vec(vec![1.0, 0.0, 0.1, 0.0]);

        let result = solver.solve(system, [0.0, 0.5], y0).unwrap();

        assert!(result.t.len() > 1);
        assert!(result.n_steps > 0);
    }

    #[test]
    fn test_extrapolated_multirate_method() {
        let system = FastSlowOscillator {
            omega_fast: 15.0,
            epsilon: 0.2,
            coupling: 0.03,
        };

        let options = MultirateOptions {
            method: MultirateMethod::Extrapolated {
                base_ratio: 5,
                levels: 2,
            },
            macro_step: 0.02,
            rtol: 1e-8,
            atol: 1e-11,
            max_steps: 500,
            timescale_ratio: Some(75.0),
        };

        let mut solver = MultirateSolver::new(options);
        let y0 = Array1::from_vec(vec![0.5, 0.0, 0.2, 0.1]);

        let result = solver.solve(system, [0.0, 0.2], y0).unwrap();

        assert!(result.t.len() > 1);
        assert!(result.n_steps > 0);
    }

    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let system = FastSlowOscillator {
            omega_fast: 10.0,
            epsilon: 0.1,
            coupling: 0.05,
        };

        let mut solver = MultirateSolver::new(MultirateOptions::default());

        // Three entries, but the oscillator needs 2 slow + 2 fast = 4.
        let y0 = Array1::from_vec(vec![1.0, 0.0, 0.1]);

        assert!(solver.solve(system, [0.0, 1.0], y0).is_err());
    }
}
587}