1use crate::common::IntegrateFloat;
8use crate::error::{IntegrateError, IntegrateResult};
9use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
11use std::fmt::Debug;
12
13#[derive(Debug, Clone)]
15pub struct ExtrapolationOptions<F: IntegrateFloat> {
16 pub max_order: usize,
18 pub min_order: usize,
20 pub base_method: ExtrapolationBaseMethod,
22 pub extrapolation_tol: F,
24 pub safety_factor: F,
26 pub max_increase_factor: F,
28 pub max_decrease_factor: F,
30}
31
32impl<F: IntegrateFloat> Default for ExtrapolationOptions<F> {
33 fn default() -> Self {
34 Self {
35 max_order: 10,
36 min_order: 3,
37 base_method: ExtrapolationBaseMethod::ModifiedMidpoint,
38 extrapolation_tol: F::from_f64(1e-12).unwrap(),
39 safety_factor: F::from_f64(0.9).unwrap(),
40 max_increase_factor: F::from_f64(1.5).unwrap(),
41 max_decrease_factor: F::from_f64(0.5).unwrap(),
42 }
43 }
44}
45
/// Integrator used to produce the raw approximations that the extrapolation
/// table refines.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare or key on the
/// selected method (for example in configuration checks or lookup tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtrapolationBaseMethod {
    /// Gragg's modified midpoint rule (Euler start, leapfrog steps, final smoothing).
    ModifiedMidpoint,
    /// Explicit (forward) Euler substeps.
    Euler,
    /// Classical fourth-order Runge–Kutta substeps.
    RungeKutta4,
}
56
57#[derive(Debug, Clone)]
59pub struct ExtrapolationResult<F: IntegrateFloat> {
60 pub y: Array1<F>,
62 pub error_estimate: F,
64 pub table: Array2<F>,
66 pub n_substeps: usize,
68 pub final_order: usize,
70 pub converged: bool,
72}
73
74#[allow(dead_code)]
91pub fn gragg_bulirsch_stoer_method<F, Func>(
92 f: Func,
93 t_span: [F; 2],
94 y0: Array1<F>,
95 opts: ODEOptions<F>,
96 ext_opts: Option<ExtrapolationOptions<F>>,
97) -> IntegrateResult<ODEResult<F>>
98where
99 F: IntegrateFloat,
100 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
101{
102 let [t_start, t_end] = t_span;
103 let ext_options = ext_opts.unwrap_or_default();
104
105 let mut h = opts.h0.unwrap_or_else(|| {
107 let _span = t_end - t_start;
108 _span / F::from_usize(100).unwrap()
109 });
110
111 let min_step = opts.min_step.unwrap_or_else(|| {
112 let _span = t_end - t_start;
113 _span / F::from_usize(1_000_000).unwrap()
114 });
115
116 let max_step = opts.max_step.unwrap_or_else(|| {
117 let _span = t_end - t_start;
118 _span / F::from_usize(10).unwrap()
119 });
120
121 let mut t_values = vec![t_start];
123 let mut y_values = vec![y0.clone()];
124
125 let mut t = t_start;
126 let mut y = y0;
127 let mut steps = 0;
128 let mut func_evals = 0;
129 let mut rejected_steps = 0;
130
131 while t < t_end {
132 if t + h > t_end {
134 h = t_end - t;
135 }
136
137 let result = extrapolation_step(&f, t, &y, h, &ext_options)?;
139 func_evals += result.n_substeps * (result.n_substeps + 1); let error_estimate = result.error_estimate;
143 let tolerance =
144 opts.atol + opts.rtol * y.iter().map(|&x| x.abs()).fold(F::zero(), |a, b| a.max(b));
145
146 if error_estimate <= tolerance {
147 t += h;
149 y = result.y;
150 steps += 1;
151
152 t_values.push(t);
154 y_values.push(y.clone());
155
156 if result.converged && result.final_order >= ext_options.min_order {
158 h *= ext_options.max_increase_factor.min(
159 (tolerance / error_estimate.max(F::from_f64(1e-14).unwrap()))
160 .powf(F::one() / F::from_usize(result.final_order + 1).unwrap())
161 * ext_options.safety_factor,
162 );
163 }
164 } else {
165 rejected_steps += 1;
167 h *= ext_options.max_decrease_factor.max(
168 (tolerance / error_estimate)
169 .powf(F::one() / F::from_usize(result.final_order + 1).unwrap())
170 * ext_options.safety_factor,
171 );
172 }
173
174 if h < min_step {
176 return Err(IntegrateError::StepSizeTooSmall(
177 "Step size became too small in extrapolation method".to_string(),
178 ));
179 }
180
181 h = h.min(max_step);
183
184 if steps > 100000 {
186 return Err(IntegrateError::ComputationError(
187 "Maximum number of steps exceeded in extrapolation method".to_string(),
188 ));
189 }
190 }
191
192 Ok(ODEResult {
193 t: t_values,
194 y: y_values,
195 success: true,
196 message: Some("Integration completed successfully".to_string()),
197 n_eval: func_evals,
198 n_steps: steps,
199 n_accepted: steps,
200 n_rejected: rejected_steps,
201 n_lu: 0,
202 n_jac: 0,
203 method: ODEMethod::RK45, })
205}
206
207#[allow(dead_code)]
209fn extrapolation_step<F, Func>(
210 f: &Func,
211 t: F,
212 y: &Array1<F>,
213 h: F,
214 options: &ExtrapolationOptions<F>,
215) -> IntegrateResult<ExtrapolationResult<F>>
216where
217 F: IntegrateFloat,
218 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
219{
220 let _n_dim = y.len();
221 let max_order = options.max_order;
222
223 let step_sequence: Vec<usize> = (1..=max_order).map(|i| 2 * i).collect();
225
226 let mut table = Array2::zeros((max_order, max_order));
228 let mut y_table = Vec::new();
229
230 let mut converged = false;
231 let mut final_order = 0;
232 let mut error_estimate = F::infinity();
233
234 for (i, &n_steps) in step_sequence.iter().enumerate() {
236 if i >= max_order {
237 break;
238 }
239
240 let h_sub = h / F::from_usize(n_steps).unwrap();
242 let y_approx = match options.base_method {
243 ExtrapolationBaseMethod::ModifiedMidpoint => {
244 modified_midpoint_sequence(f, t, y, h_sub, n_steps)?
245 }
246 ExtrapolationBaseMethod::Euler => euler_sequence(f, t, y, h_sub, n_steps)?,
247 ExtrapolationBaseMethod::RungeKutta4 => rk4_sequence(f, t, y, h_sub, n_steps)?,
248 };
249
250 y_table.push(y_approx.clone());
251
252 let norm = y_approx
254 .iter()
255 .map(|&x| x * x)
256 .fold(F::zero(), |a, b| a + b)
257 .sqrt();
258 table[[i, 0]] = norm;
259
260 for j in 1..=i {
262 if j >= max_order {
263 break;
264 }
265
266 let ratio = F::from_usize(step_sequence[i]).unwrap()
269 / F::from_usize(step_sequence[i - 1]).unwrap();
270 let denominator = ratio.powf(F::from_usize(2 * j).unwrap()) - F::one();
271
272 if denominator.abs() > F::from_f64(1e-14).unwrap() {
273 table[[i, j]] =
274 table[[i, j - 1]] + (table[[i, j - 1]] - table[[i - 1, j - 1]]) / denominator;
275 } else {
276 table[[i, j]] = table[[i, j - 1]];
277 }
278 }
279
280 if i >= options.min_order - 1 {
282 let current_order = i;
283 if current_order > 0 {
284 let current_est = table[[current_order, current_order]];
285 let prev_est = table[[current_order - 1, current_order - 1]];
286 error_estimate = (current_est - prev_est).abs();
287
288 if error_estimate <= options.extrapolation_tol * current_est.abs() {
289 converged = true;
290 final_order = current_order + 1;
291 break;
292 }
293 }
294 }
295
296 final_order = i + 1;
297 }
298
299 let final_y = if final_order > 0 && !y_table.is_empty() {
301 y_table[final_order - 1].clone()
303 } else {
304 y.clone()
305 };
306
307 Ok(ExtrapolationResult {
308 y: final_y,
309 error_estimate,
310 table,
311 n_substeps: step_sequence
312 .get(final_order.saturating_sub(1))
313 .copied()
314 .unwrap_or(2),
315 final_order,
316 converged,
317 })
318}
319
320#[allow(dead_code)]
322fn modified_midpoint_sequence<F, Func>(
323 f: &Func,
324 t0: F,
325 y0: &Array1<F>,
326 h_sub: F,
327 n_steps: usize,
328) -> IntegrateResult<Array1<F>>
329where
330 F: IntegrateFloat,
331 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
332{
333 if n_steps == 0 {
334 return Ok(y0.clone());
335 }
336
337 let mut y = y0.clone();
338 let mut y_prev = y0.clone();
339 let mut t = t0;
340
341 if n_steps >= 1 {
343 let dy = f(t, y.view());
344 let y_next = &y + &dy * h_sub;
345 y_prev = y.clone();
346 y = y_next;
347 t += h_sub;
348 }
349
350 for _ in 1..n_steps {
352 let dy = f(t, y.view());
353 let y_next = &y_prev + &dy * (F::from_f64(2.0).unwrap() * h_sub);
354 y_prev = y.clone();
355 y = y_next;
356 t += h_sub;
357 }
358
359 if n_steps > 1 {
361 let dy = f(t, y.view());
362 y = (&y + &y_prev + &dy * h_sub) * F::from_f64(0.5).unwrap();
363 }
364
365 Ok(y)
366}
367
368#[allow(dead_code)]
370fn euler_sequence<F, Func>(
371 f: &Func,
372 t0: F,
373 y0: &Array1<F>,
374 h_sub: F,
375 n_steps: usize,
376) -> IntegrateResult<Array1<F>>
377where
378 F: IntegrateFloat,
379 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
380{
381 let mut y = y0.clone();
382 let mut t = t0;
383
384 for _ in 0..n_steps {
385 let dy = f(t, y.view());
386 y = &y + &dy * h_sub;
387 t += h_sub;
388 }
389
390 Ok(y)
391}
392
393#[allow(dead_code)]
395fn rk4_sequence<F, Func>(
396 f: &Func,
397 t0: F,
398 y0: &Array1<F>,
399 h_sub: F,
400 n_steps: usize,
401) -> IntegrateResult<Array1<F>>
402where
403 F: IntegrateFloat,
404 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
405{
406 let mut y = y0.clone();
407 let mut t = t0;
408 let h_half = h_sub * F::from_f64(0.5).unwrap();
409 let h_sixth = h_sub / F::from_f64(6.0).unwrap();
410
411 for _ in 0..n_steps {
412 let k1 = f(t, y.view());
413 let k2 = f(t + h_half, (&y + &k1 * h_half).view());
414 let k3 = f(t + h_half, (&y + &k2 * h_half).view());
415 let k4 = f(t + h_sub, (&y + &k3 * h_sub).view());
416
417 y = &y
418 + (&k1 + &k2 * F::from_f64(2.0).unwrap() + &k3 * F::from_f64(2.0).unwrap() + &k4)
419 * h_sixth;
420 t += h_sub;
421 }
422
423 Ok(y)
424}
425
426#[allow(dead_code)]
431pub fn richardson_extrapolation_step<F, Func, Method>(
432 method: Method,
433 f: &Func,
434 t: F,
435 y: &Array1<F>,
436 h: F,
437) -> IntegrateResult<(Array1<F>, F)>
438where
439 F: IntegrateFloat,
440 Func: Fn(F, ArrayView1<F>) -> Array1<F> + ?Sized,
441 Method: Fn(&Func, F, &Array1<F>, F) -> IntegrateResult<Array1<F>>,
442{
443 let y1 = method(f, t, y, h)?;
445
446 let h_half = h * F::from_f64(0.5).unwrap();
448 let y_mid = method(f, t, y, h_half)?;
449 let y2 = method(f, t + h_half, &y_mid, h_half)?;
450
451 let y_extrapolated = (&y2 * F::from_f64(4.0).unwrap() - &y1) / F::from_f64(3.0).unwrap();
454
455 let error_estimate = (&y2 - &y1)
457 .iter()
458 .map(|&x| x.abs())
459 .fold(F::zero(), |a, b| a.max(b))
460 / F::from_f64(3.0).unwrap();
461
462 Ok((y_extrapolated, error_estimate))
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Right-hand side of the scalar decay problem `y' = -y`.
    fn decay(_t: f64, y: ArrayView1<f64>) -> Array1<f64> {
        -y.to_owned()
    }

    /// One explicit-midpoint (second-order) step, used as the base method for
    /// Richardson extrapolation.
    fn explicit_midpoint(
        f: &fn(f64, ArrayView1<f64>) -> Array1<f64>,
        t: f64,
        y: &Array1<f64>,
        h: f64,
    ) -> IntegrateResult<Array1<f64>> {
        let k1 = f(t, y.view());
        let y_half = y + &(&k1 * (0.5 * h));
        let k2 = f(t + 0.5 * h, y_half.view());
        Ok(y + &(&k2 * h))
    }

    #[test]
    fn test_modified_midpoint_sequence() {
        let f = |_t: f64, y: ArrayView1<f64>| -y.to_owned();
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;
        let n_steps = 10;

        let result =
            modified_midpoint_sequence(&f, 0.0, &y0, h / n_steps as f64, n_steps).unwrap();
        let exact = (-h).exp();

        assert_relative_eq!(result[0], exact, epsilon = 1e-3);
    }

    /// Exercises the full Gragg–Bulirsch–Stoer driver on exponential decay.
    #[test]
    fn test_gragg_bulirsch_stoer_exponential_decay() {
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;

        let f = |_t: f64, y: ArrayView1<f64>| -y.to_owned();
        let result =
            gragg_bulirsch_stoer_method(f, [0.0, h], y0.clone(), ODEOptions::default(), None)
                .unwrap();

        let exact = (-h).exp();
        let final_value = result.y.last().unwrap()[0];

        assert!(result.success);
        assert_relative_eq!(final_value, exact, epsilon = 1e-6);
    }

    /// Richardson extrapolation of a second-order method should gain accuracy, and its
    /// error estimate should be positive and small.
    #[test]
    fn test_richardson_extrapolation_step() {
        let rhs: fn(f64, ArrayView1<f64>) -> Array1<f64> = decay;
        let y0 = Array1::from_vec(vec![1.0]);
        let h = 0.1;

        let (y_ext, error_estimate) =
            richardson_extrapolation_step(explicit_midpoint, &rhs, 0.0, &y0, h).unwrap();

        assert_relative_eq!(y_ext[0], (-h).exp(), epsilon = 1e-5);
        assert!(error_estimate > 0.0 && error_estimate < 1e-4);
    }

    #[test]
    fn test_extrapolation_options_default() {
        let opts: ExtrapolationOptions<f64> = Default::default();
        assert_eq!(opts.max_order, 10);
        assert_eq!(opts.min_order, 3);
        assert_eq!(opts.safety_factor, 0.9);
    }
}