1use crate::common::IntegrateFloat;
12use crate::error::{IntegrateError, IntegrateResult};
13use crate::ode::{ODEMethod, ODEResult};
14use scirs2_core::ndarray::{s, Array1, ArrayView1};
15use std::collections::VecDeque;
16
/// An ODE system whose state is partitioned into slow and fast variables.
///
/// The solver in this file stores the full state as a single vector in which
/// the first `slow_dim()` entries are the slow variables and the remaining
/// `fast_dim()` entries are the fast variables.
pub trait MultirateSystem<F: IntegrateFloat> {
    /// Time derivative of the slow variables at time `t`, given the current
    /// slow state `y_slow` and fast state `yfast`.
    ///
    /// The solver copies the result into the slow part of the state vector,
    /// so it is expected to have `slow_dim()` entries.
    fn slow_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F>;

    /// Time derivative of the fast variables at time `t`, given the current
    /// slow state `y_slow` and fast state `yfast`.
    ///
    /// The solver copies the result into the fast part of the state vector,
    /// so it is expected to have `fast_dim()` entries.
    fn fast_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F>;

    /// Number of slow variables (leading entries of the state vector).
    fn slow_dim(&self) -> usize;

    /// Number of fast variables (trailing entries of the state vector).
    fn fast_dim(&self) -> usize;
}
31
/// Selects the stepping scheme used by `MultirateSolver`.
#[derive(Debug, Clone)]
pub enum MultirateMethod {
    /// Explicit multirate Runge-Kutta step: the fast variables are advanced
    /// with `micro_steps` RK4 sub-steps per macro step.
    ExplicitMRK {
        /// Forwarded to the step routine, which currently does not read it.
        macro_steps: usize,
        /// Number of fast sub-steps taken inside each macro step.
        micro_steps: usize,
    },
    /// Implicit-explicit variant. NOTE(review): the step routine in this file
    /// currently delegates to the explicit multirate step.
    IMEX {
        /// Forwarded to the step routine, which currently does not read it.
        macro_steps: usize,
        /// Number of fast sub-steps taken inside each macro step.
        micro_steps: usize,
    },
    /// Alternating fast/slow integration. The two method fields are ignored
    /// by the solver in this file (hence the leading underscores).
    CompoundFastSlow {
        /// Requested method for the fast subsystem (currently unused).
        _fast_method: ODEMethod,
        /// Requested method for the slow subsystem (currently unused).
        _slow_method: ODEMethod,
    },
    /// Extrapolation over explicit multirate steps: level `k` uses
    /// `base_ratio * 2^k` micro steps, for `levels` levels.
    Extrapolated { base_ratio: usize, levels: usize },
}
53
/// Configuration for `MultirateSolver`.
#[derive(Debug, Clone)]
pub struct MultirateOptions<F: IntegrateFloat> {
    /// Stepping scheme to use.
    pub method: MultirateMethod,
    /// Macro (slow) step size; the solver uses it for every step except a
    /// shortened final one.
    pub macro_step: F,
    /// Relative tolerance. NOTE(review): not read by the solver code in this
    /// file — confirm whether it is consumed elsewhere.
    pub rtol: F,
    /// Absolute tolerance. NOTE(review): not read by the solver code in this
    /// file — confirm whether it is consumed elsewhere.
    pub atol: F,
    /// Upper bound on the number of macro steps taken by `solve`.
    pub max_steps: usize,
    /// Optional fast/slow time-scale ratio. NOTE(review): not read by the
    /// solver code in this file.
    pub timescale_ratio: Option<F>,
}
70
71impl<F: IntegrateFloat> Default for MultirateOptions<F> {
72 fn default() -> Self {
73 Self {
74 method: MultirateMethod::ExplicitMRK {
75 macro_steps: 4,
76 micro_steps: 10,
77 },
78 macro_step: F::from(0.01).unwrap(),
79 rtol: F::from(1e-6).unwrap(),
80 atol: F::from(1e-9).unwrap(),
81 max_steps: 10000,
82 timescale_ratio: None,
83 }
84 }
85}
86
/// Fixed-step multirate ODE solver driven by a `MultirateOptions` value.
pub struct MultirateSolver<F: IntegrateFloat> {
    /// Configuration supplied at construction.
    options: MultirateOptions<F>,
    /// Recent `(t, y)` pairs; `solve` appends to it only for the
    /// `Extrapolated` method and keeps at most 10 entries.
    history: VecDeque<(F, Array1<F>)>,
    /// Macro step size currently in use (initialised from the options).
    current_macro_step: F,
    /// Micro step size derived in `new`; it is not read by the solver code
    /// in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    current_micro_step: F,
}
98
99impl<F: IntegrateFloat> MultirateSolver<F> {
100 pub fn new(options: MultirateOptions<F>) -> Self {
102 let current_macro_step = options.macro_step;
103 let current_micro_step = match &options.method {
104 MultirateMethod::ExplicitMRK { micro_steps, .. } => {
105 current_macro_step / F::from(*micro_steps).unwrap()
106 }
107 MultirateMethod::IMEX { micro_steps, .. } => {
108 current_macro_step / F::from(*micro_steps).unwrap()
109 }
110 _ => current_macro_step / F::from(10).unwrap(),
111 };
112
113 Self {
114 options,
115 history: VecDeque::new(),
116 current_macro_step,
117 current_micro_step,
118 }
119 }
120
121 pub fn solve<S>(
123 &mut self,
124 system: S,
125 t_span: [F; 2],
126 y0: Array1<F>,
127 ) -> IntegrateResult<ODEResult<F>>
128 where
129 S: MultirateSystem<F>,
130 {
131 let [t0, tf] = t_span;
132 let slow_dim = system.slow_dim();
133 let fast_dim = system.fast_dim();
134
135 if y0.len() != slow_dim + fast_dim {
136 return Err(IntegrateError::ValueError(format!(
137 "Initial condition dimension {} does not match system dimension {}",
138 y0.len(),
139 slow_dim + fast_dim
140 )));
141 }
142
143 let mut t = t0;
144 let mut y = y0.clone();
145 let mut solution_t = vec![t];
146 let mut solution_y = vec![y.clone()];
147 let mut step_count = 0;
148
149 while t < tf && step_count < self.options.max_steps {
150 let dt = if t + self.current_macro_step > tf {
152 tf - t
153 } else {
154 self.current_macro_step
155 };
156
157 let y_slow = y.slice(s![..slow_dim]).to_owned();
159 let y_fast = y.slice(s![slow_dim..]).to_owned();
160
161 let (new_y_slow, new_y_fast) = match &self.options.method {
163 MultirateMethod::ExplicitMRK {
164 macro_steps,
165 micro_steps,
166 } => self.explicit_mrk_step(
167 &system,
168 t,
169 dt,
170 y_slow.view(),
171 y_fast.view(),
172 *macro_steps,
173 *micro_steps,
174 )?,
175 MultirateMethod::IMEX {
176 macro_steps,
177 micro_steps,
178 } => self.imex_step(
179 &system,
180 t,
181 dt,
182 y_slow.view(),
183 y_fast.view(),
184 *macro_steps,
185 *micro_steps,
186 )?,
187 MultirateMethod::CompoundFastSlow {
188 _fast_method: _,
189 _slow_method: _,
190 } => self.compound_fast_slow_step(&system, t, dt, y_slow.view(), y_fast.view())?,
191 MultirateMethod::Extrapolated { base_ratio, levels } => self.extrapolated_step(
192 &system,
193 t,
194 dt,
195 y_slow.view(),
196 y_fast.view(),
197 *base_ratio,
198 *levels,
199 )?,
200 };
201
202 let mut new_y = Array1::zeros(slow_dim + fast_dim);
204 new_y.slice_mut(s![..slow_dim]).assign(&new_y_slow);
205 new_y.slice_mut(s![slow_dim..]).assign(&new_y_fast);
206
207 t += dt;
208 y = new_y;
209 solution_t.push(t);
210 solution_y.push(y.clone());
211 step_count += 1;
212
213 if matches!(self.options.method, MultirateMethod::Extrapolated { .. }) {
215 self.history.push_back((t, y.clone()));
216 if self.history.len() > 10 {
217 self.history.pop_front();
218 }
219 }
220 }
221
222 if step_count >= self.options.max_steps {
223 return Err(IntegrateError::ConvergenceError(
224 "Maximum number of steps exceeded in multirate solver".to_string(),
225 ));
226 }
227
228 Ok(ODEResult {
229 t: solution_t,
230 y: solution_y,
231 success: true,
232 message: Some(format!("Multirate method: {:?}", self.options.method)),
233 n_eval: step_count * 4, n_steps: step_count,
235 n_accepted: step_count,
236 n_rejected: 0,
237 n_lu: 0,
238 n_jac: 0,
239 method: ODEMethod::RK4, })
241 }
242
243 fn explicit_mrk_step<S>(
245 &self,
246 system: &S,
247 t: F,
248 dt: F,
249 y_slow: ArrayView1<F>,
250 y_fast: ArrayView1<F>,
251 _macro_steps: usize,
252 micro_steps: usize,
253 ) -> IntegrateResult<(Array1<F>, Array1<F>)>
254 where
255 S: MultirateSystem<F>,
256 {
257 let dt_micro = dt / F::from(micro_steps).unwrap();
258
259 let k1_slow = system.slow_rhs(t, y_slow, y_fast);
261
262 let mut y_fast_current = y_fast.to_owned();
264 let mut t_micro = t;
265
266 for _ in 0..micro_steps {
267 let k1_fast = system.fast_rhs(t_micro, y_slow, y_fast_current.view());
269 let k2_fast = system.fast_rhs(
270 t_micro + dt_micro / F::from(2).unwrap(),
271 y_slow,
272 (y_fast_current.clone() + k1_fast.clone() * dt_micro / F::from(2).unwrap()).view(),
273 );
274 let k3_fast = system.fast_rhs(
275 t_micro + dt_micro / F::from(2).unwrap(),
276 y_slow,
277 (y_fast_current.clone() + k2_fast.clone() * dt_micro / F::from(2).unwrap()).view(),
278 );
279 let k4_fast = system.fast_rhs(
280 t_micro + dt_micro,
281 y_slow,
282 (y_fast_current.clone() + k3_fast.clone() * dt_micro).view(),
283 );
284
285 let two = F::from(2).unwrap();
286 let six = F::from(6).unwrap();
287 let rk_sum = k1_fast.clone() + &k2_fast * two + &k3_fast * two + k4_fast.clone();
288 y_fast_current = y_fast_current + &rk_sum * (dt_micro / six);
289 t_micro += dt_micro;
290 }
291
292 let k2_slow = system.slow_rhs(t + dt / F::from(2).unwrap(), y_slow, y_fast_current.view());
294 let k3_slow = system.slow_rhs(
295 t + dt / F::from(2).unwrap(),
296 (y_slow.to_owned() + k1_slow.clone() * dt / F::from(2).unwrap()).view(),
297 y_fast_current.view(),
298 );
299 let k4_slow = system.slow_rhs(
300 t + dt,
301 (y_slow.to_owned() + k3_slow.clone() * dt).view(),
302 y_fast_current.view(),
303 );
304
305 let two = F::from(2).unwrap();
306 let six = F::from(6).unwrap();
307 let rk_sum_slow = k1_slow.clone() + &k2_slow * two + &k3_slow * two + k4_slow.clone();
308 let new_y_slow = y_slow.to_owned() + &rk_sum_slow * (dt / six);
309
310 Ok((new_y_slow, y_fast_current))
311 }
312
313 fn imex_step<S>(
315 &self,
316 system: &S,
317 t: F,
318 dt: F,
319 y_slow: ArrayView1<F>,
320 y_fast: ArrayView1<F>,
321 _macro_steps: usize,
322 micro_steps: usize,
323 ) -> IntegrateResult<(Array1<F>, Array1<F>)>
324 where
325 S: MultirateSystem<F>,
326 {
327 self.explicit_mrk_step(system, t, dt, y_slow, y_fast, _macro_steps, micro_steps)
330 }
331
332 fn compound_fast_slow_step<S>(
334 &self,
335 system: &S,
336 t: F,
337 dt: F,
338 y_slow: ArrayView1<F>,
339 y_fast: ArrayView1<F>,
340 ) -> IntegrateResult<(Array1<F>, Array1<F>)>
341 where
342 S: MultirateSystem<F>,
343 {
344 let mut y_fast_current = y_fast.to_owned();
346 let dt_fast = dt / F::from(100).unwrap(); for _ in 0..50 {
350 let k_fast = system.fast_rhs(t, y_slow, y_fast_current.view());
352 y_fast_current = y_fast_current + k_fast * dt_fast;
353 }
354
355 let k_slow = system.slow_rhs(t, y_slow, y_fast_current.view());
357 let new_y_slow = y_slow.to_owned() + k_slow * dt;
358
359 let k_fast_final = system.fast_rhs(t + dt, new_y_slow.view(), y_fast_current.view());
361 let new_y_fast = y_fast_current + k_fast_final * dt;
362
363 Ok((new_y_slow, new_y_fast))
364 }
365
366 fn extrapolated_step<S>(
368 &self,
369 system: &S,
370 t: F,
371 dt: F,
372 y_slow: ArrayView1<F>,
373 y_fast: ArrayView1<F>,
374 base_ratio: usize,
375 levels: usize,
376 ) -> IntegrateResult<(Array1<F>, Array1<F>)>
377 where
378 S: MultirateSystem<F>,
379 {
380 let mut solutions = Vec::new();
382
383 for level in 0..levels {
384 let micro_steps = base_ratio * (2_usize.pow(level as u32));
385 let (y_slow_approx, y_fast_approx) =
386 self.explicit_mrk_step(system, t, dt, y_slow, y_fast, 4, micro_steps)?;
387 solutions.push((y_slow_approx, y_fast_approx));
388 }
389
390 if solutions.len() >= 2 {
392 let (y_slow_coarse, y_fast_coarse) = &solutions[0];
393 let (y_slow_fine, y_fast_fine) = &solutions[1];
394
395 let y_slow_ext = y_slow_fine + (y_slow_fine - y_slow_coarse);
397 let y_fast_ext = y_fast_fine + (y_fast_fine - y_fast_coarse);
398
399 Ok((y_slow_ext, y_fast_ext))
400 } else {
401 Ok(solutions.into_iter().next().unwrap())
402 }
403 }
404}
405
/// Example system: a slow oscillator linearly coupled to a fast oscillator.
///
/// Both the slow and the fast state are `[position, velocity]` pairs.
pub struct FastSlowOscillator<F: IntegrateFloat> {
    /// Coefficient of the fast oscillator; its restoring term is
    /// `-omega_fast^2 * x_fast`.
    pub omega_fast: F,
    /// Stiffness of the slow oscillator; its restoring term is
    /// `-epsilon * x_slow`.
    pub epsilon: F,
    /// Strength of the position coupling, applied in both directions.
    pub coupling: F,
}
415
416impl<F: IntegrateFloat> MultirateSystem<F> for FastSlowOscillator<F> {
417 fn slow_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F> {
418 let x_slow = y_slow[0];
419 let v_slow = y_slow[1];
420 let x_fast = yfast[0];
421
422 let dx_slow_dt = v_slow;
424 let dv_slow_dt = -self.epsilon * x_slow + self.coupling * x_fast;
425
426 Array1::from_vec(vec![dx_slow_dt, dv_slow_dt])
427 }
428
429 fn fast_rhs(&self, t: F, y_slow: ArrayView1<F>, yfast: ArrayView1<F>) -> Array1<F> {
430 let x_slow = y_slow[0];
431 let x_fast = yfast[0];
432 let v_fast = yfast[1];
433
434 let dx_fast_dt = v_fast;
436 let dv_fast_dt = -self.omega_fast * self.omega_fast * x_fast + self.coupling * x_slow;
437
438 Array1::from_vec(vec![dx_fast_dt, dv_fast_dt])
439 }
440
441 fn slow_dim(&self) -> usize {
442 2
443 }
444 fn fast_dim(&self) -> usize {
445 2
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The example oscillator exposes two slow and two fast variables.
    #[test]
    fn test_multirate_system_dimensions() {
        let oscillator = FastSlowOscillator {
            omega_fast: 10.0,
            epsilon: 0.1,
            coupling: 0.05,
        };

        assert_eq!(oscillator.slow_dim(), 2);
        assert_eq!(oscillator.fast_dim(), 2);
        assert_eq!(oscillator.slow_dim() + oscillator.fast_dim(), 4);
    }

    /// `new` copies the macro step and derives the micro step from it.
    #[test]
    fn test_multirate_solver_creation() {
        let method = MultirateMethod::ExplicitMRK {
            macro_steps: 4,
            micro_steps: 10,
        };
        let solver = MultirateSolver::new(MultirateOptions {
            method,
            macro_step: 0.01,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 1000,
            timescale_ratio: Some(100.0),
        });

        assert_abs_diff_eq!(solver.current_macro_step, 0.01);
        assert_abs_diff_eq!(solver.current_micro_step, 0.001);
    }

    /// The explicit multirate method produces a trajectory in which both the
    /// fast and the slow component actually move.
    #[test]
    fn test_fast_slow_oscillator_solve() {
        let oscillator = FastSlowOscillator {
            omega_fast: 20.0,
            epsilon: 0.1,
            coupling: 0.02,
        };
        let method = MultirateMethod::ExplicitMRK {
            macro_steps: 4,
            micro_steps: 20,
        };
        let mut solver = MultirateSolver::new(MultirateOptions {
            method,
            macro_step: 0.05,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 200,
            timescale_ratio: Some(200.0),
        });

        // State layout: [x_slow, v_slow, x_fast, v_fast].
        let y0 = Array1::from_vec(vec![1.0, 0.0, 0.1, 0.0]);

        let result = solver.solve(oscillator, [0.0, 1.0], y0.clone()).unwrap();

        // Shape of the returned trajectory.
        assert!(result.t.len() > 1);
        assert_eq!(result.y.len(), result.t.len());
        assert_eq!(result.y[0].len(), 4);

        let final_state = result.y.last().unwrap();

        // The fast oscillator must have picked up some velocity.
        let fast_velocity: f64 = final_state[3];
        assert!(fast_velocity.abs() > 1e-6);

        // The slow position must have drifted away from its initial value.
        let slow_pos_change: f64 = final_state[0] - y0[0];
        assert!(slow_pos_change.abs() > 1e-3);
    }

    /// The compound fast/slow method runs to completion and takes steps.
    #[test]
    fn test_compound_fast_slow_method() {
        let oscillator = FastSlowOscillator {
            omega_fast: 50.0,
            epsilon: 0.05,
            coupling: 0.01,
        };
        let method = MultirateMethod::CompoundFastSlow {
            _fast_method: ODEMethod::RK4,
            _slow_method: ODEMethod::RK4,
        };
        let mut solver = MultirateSolver::new(MultirateOptions {
            method,
            macro_step: 0.1,
            rtol: 1e-6,
            atol: 1e-9,
            max_steps: 100,
            timescale_ratio: Some(1000.0),
        });

        let y0 = Array1::from_vec(vec![1.0, 0.0, 0.1, 0.0]);
        let result = solver.solve(oscillator, [0.0, 0.5], y0).unwrap();

        assert!(result.t.len() > 1);
        assert!(result.n_steps > 0);
    }

    /// The extrapolated method runs to completion and takes steps.
    #[test]
    fn test_extrapolated_multirate_method() {
        let oscillator = FastSlowOscillator {
            omega_fast: 15.0,
            epsilon: 0.2,
            coupling: 0.03,
        };
        let method = MultirateMethod::Extrapolated {
            base_ratio: 5,
            levels: 2,
        };
        let mut solver = MultirateSolver::new(MultirateOptions {
            method,
            macro_step: 0.02,
            rtol: 1e-8,
            atol: 1e-11,
            max_steps: 500,
            timescale_ratio: Some(75.0),
        });

        let y0 = Array1::from_vec(vec![0.5, 0.0, 0.2, 0.1]);
        let result = solver.solve(oscillator, [0.0, 0.2], y0).unwrap();

        assert!(result.t.len() > 1);
        assert!(result.n_steps > 0);
    }
}