1use crate::error::{IntegrateError, IntegrateResult};
9use crate::ode::types::{ODEMethod, ODEOptions, ODEResult};
10use crate::IntegrateFloat;
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
12use std::fmt::Debug;
13
/// Integration family currently driving the LSODA solver.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum LsodaMethodType {
    /// Adams–Bashforth predictor followed by an Adams–Moulton corrector
    /// (used while the problem looks non-stiff).
    Adams,
    /// Backward differentiation formulas solved with Newton iteration
    /// (used while the problem looks stiff).
    Bdf,
}
22
23struct LsodaState<F: IntegrateFloat> {
25 t: F,
27 y: Array1<F>,
29 dy: Array1<F>,
31 h: F,
33 t_history: Vec<F>,
35 y_history: Vec<Array1<F>>,
37 dy_history: Vec<Array1<F>>,
39 method_type: LsodaMethodType,
41 order: usize,
43 jacobian: Option<Array2<F>>,
45 jacobian_age: usize,
47 stiff_to_nonstiff_switches: usize,
49 nonstiff_to_stiff_switches: usize,
50 steps_since_switch: usize,
52 func_evals: usize,
54 n_lu: usize,
56 n_jac: usize,
58 steps: usize,
60 accepted_steps: usize,
62 rejected_steps: usize,
64 stiffness_detected_count: usize,
66 non_stiffness_detected_count: usize,
67 recently_switched: bool,
69 tol_scale: Array1<F>,
71}
72
73impl<F: IntegrateFloat> LsodaState<F> {
74 fn new(t: F, y: Array1<F>, dy: Array1<F>, h: F, rtol: F, atol: F) -> Self {
76 let n_dim = y.len();
77
78 let mut tol_scale = Array1::<F>::zeros(n_dim);
80 for i in 0..n_dim {
81 tol_scale[i] = atol + rtol * y[i].abs();
82 }
83
84 LsodaState {
85 t,
86 y: y.clone(),
87 dy: dy.clone(),
88 h,
89 t_history: vec![t],
90 y_history: vec![y],
91 dy_history: vec![dy],
92 method_type: LsodaMethodType::Adams, order: 1, jacobian: None,
95 jacobian_age: 0,
96 stiff_to_nonstiff_switches: 0,
97 nonstiff_to_stiff_switches: 0,
98 steps_since_switch: 0,
99 func_evals: 0,
100 n_lu: 0,
101 n_jac: 0,
102 steps: 0,
103 accepted_steps: 0,
104 rejected_steps: 0,
105 stiffness_detected_count: 0,
106 non_stiffness_detected_count: 0,
107 recently_switched: false,
108 tol_scale,
109 }
110 }
111
112 fn update_tol_scale(&mut self, rtol: F, atol: F) {
114 for i in 0..self.y.len() {
115 self.tol_scale[i] = atol + rtol * self.y[i].abs();
116 }
117 }
118
119 fn add_to_history(&mut self) {
121 self.t_history.push(self.t);
122 self.y_history.push(self.y.clone());
123 self.dy_history.push(self.dy.clone());
124
125 let max_history = match self.method_type {
127 LsodaMethodType::Adams => 12, LsodaMethodType::Bdf => 5, };
130
131 if self.t_history.len() > max_history {
132 self.t_history.remove(0);
133 self.y_history.remove(0);
134 self.dy_history.remove(0);
135 }
136 }
137
138 fn switch_method(&mut self, _newmethod: LsodaMethodType) {
140 if self.method_type == LsodaMethodType::Adams && _newmethod == LsodaMethodType::Bdf {
142 self.nonstiff_to_stiff_switches += 1;
143
144 self.order = 1;
146 self.jacobian = None;
147 self.jacobian_age = 0;
148 } else if self.method_type == LsodaMethodType::Bdf && _newmethod == LsodaMethodType::Adams {
149 self.stiff_to_nonstiff_switches += 1;
150
151 self.order = 1;
153
154 if self.rejected_steps > 2 {
156 let half = F::from_f64(0.5)
157 .ok_or_else(|| {
158 IntegrateError::ComputationError(
159 "Failed to convert constant 0.5 to float type".to_string(),
160 )
161 })
162 .unwrap_or_else(|_| F::from(0.5).unwrap()); self.h *= half;
164 }
165 }
166
167 self.steps_since_switch = 0;
169 self.recently_switched = true;
170
171 self.method_type = _newmethod;
173 }
174}
175
176struct StiffnessDetector<F: IntegrateFloat> {
178 min_steps_before_switch: usize,
180 stiffness_threshold: usize,
182 non_stiffness_threshold: usize,
184 #[allow(dead_code)]
186 step_size_ratio_threshold: F,
187}
188
189impl<F: IntegrateFloat> StiffnessDetector<F> {
190 fn new() -> Self {
192 StiffnessDetector {
193 min_steps_before_switch: 5,
194 stiffness_threshold: 3,
195 non_stiffness_threshold: 5,
196 step_size_ratio_threshold: F::from_f64(0.1)
197 .ok_or_else(|| {
198 IntegrateError::ComputationError(
199 "Failed to convert constant 0.1 to float type".to_string(),
200 )
201 })
202 .unwrap_or_else(|_| F::from(0.1).unwrap()), }
204 }
205
206 fn is_stiff(&self, state: &LsodaState<F>) -> bool {
208 if state.steps_since_switch < self.min_steps_before_switch {
210 return false;
211 }
212
213 if state.method_type == LsodaMethodType::Bdf {
215 return state.non_stiffness_detected_count < self.non_stiffness_threshold;
216 }
217
218 state.stiffness_detected_count >= self.stiffness_threshold
220 }
221}
222
223#[allow(dead_code)]
251pub fn lsoda_method<F, Func>(
252 f: Func,
253 t_span: [F; 2],
254 y0: Array1<F>,
255 opts: ODEOptions<F>,
256) -> IntegrateResult<ODEResult<F>>
257where
258 F: IntegrateFloat,
259 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
260{
261 let [t_start, t_end] = t_span;
263 let _n_dim = y0.len();
264
265 let dy0 = f(t_start, y0.view());
267 let mut func_evals = 1;
268
269 let h0 = opts.h0.unwrap_or_else(|| {
271 let _span = t_end - t_start;
273 let hundred = F::from_usize(100).unwrap_or_else(|| F::from(100).unwrap());
274 let tenth = F::from_f64(0.1).unwrap_or_else(|| F::from(0.1).unwrap());
275 _span / hundred * tenth });
277
278 let min_step = opts.min_step.unwrap_or_else(|| {
280 let _span = t_end - t_start;
281 let epsilon = F::from_f64(1e-10).unwrap_or_else(|| F::from(1e-10).unwrap());
282 _span * epsilon });
284
285 let max_step = opts.max_step.unwrap_or_else(|| {
286 t_end - t_start });
288
289 let mut state = LsodaState::new(t_start, y0.clone(), dy0, h0, opts.rtol, opts.atol);
291 let stiffness_detector = StiffnessDetector::new();
292
293 let mut t_values = vec![t_start];
295 let mut y_values = vec![y0.clone()];
296
297 while state.t < t_end && state.steps < opts.max_steps {
299 if state.recently_switched
301 && state.steps_since_switch >= stiffness_detector.min_steps_before_switch
302 {
303 state.recently_switched = false;
304 }
305
306 if state.t + state.h > t_end {
308 state.h = t_end - state.t;
309 }
310
311 state.h = state.h.min(max_step).max(min_step);
313
314 let step_result = match state.method_type {
316 LsodaMethodType::Adams => adams_step(&mut state, &f, &opts, &mut func_evals),
317 LsodaMethodType::Bdf => bdf_step(&mut state, &f, &opts, &mut func_evals),
318 };
319
320 state.steps += 1;
321 state.steps_since_switch += 1;
322
323 match step_result {
324 Ok(accepted) => {
325 if accepted {
326 state.add_to_history();
330 t_values.push(state.t);
331 y_values.push(state.y.clone());
332
333 state.accepted_steps += 1;
334
335 if !state.recently_switched {
337 let is_stiff = stiffness_detector.is_stiff(&state);
338
339 if state.method_type == LsodaMethodType::Adams && is_stiff {
340 state.switch_method(LsodaMethodType::Bdf);
342 } else if state.method_type == LsodaMethodType::Bdf && !is_stiff {
343 state.switch_method(LsodaMethodType::Adams);
345 }
346 }
347
348 state.update_tol_scale(opts.rtol, opts.atol);
350
351 if state.method_type == LsodaMethodType::Bdf && state.jacobian.is_some() {
353 state.jacobian_age += 1;
354 }
355 } else {
356 state.rejected_steps += 1;
358 }
359 }
360 Err(e) => {
361 match &e {
363 IntegrateError::ConvergenceError(msg) if msg.contains("stiff") => {
364 if state.method_type == LsodaMethodType::Adams {
365 state.stiffness_detected_count += 1;
367 state.switch_method(LsodaMethodType::Bdf);
368
369 let half = F::from_f64(0.5).unwrap_or_else(|| F::from(0.5).unwrap());
371 state.h *= half;
372 if state.h < min_step {
373 return Err(IntegrateError::ConvergenceError(
374 "Step size too small after method switch".to_string(),
375 ));
376 }
377 } else {
378 return Err(e);
380 }
381 }
382 IntegrateError::ConvergenceError(msg) if msg.contains("non-stiff") => {
383 if state.method_type == LsodaMethodType::Bdf {
384 state.non_stiffness_detected_count += 1;
386 state.switch_method(LsodaMethodType::Adams);
387
388 let half = F::from_f64(0.5).unwrap_or_else(|| F::from(0.5).unwrap());
390 state.h *= half;
391 if state.h < min_step {
392 return Err(IntegrateError::ConvergenceError(
393 "Step size too small after method switch".to_string(),
394 ));
395 }
396 } else {
397 return Err(e);
399 }
400 }
401 _ => return Err(e), }
403 }
404 }
405 }
406
407 let success = state.t >= t_end;
408 let message = if !success {
409 Some(format!(
410 "Maximum number of steps ({}) reached",
411 opts.max_steps
412 ))
413 } else {
414 None
415 };
416
417 Ok(ODEResult {
419 t: t_values,
420 y: y_values,
421 success,
422 message,
423 n_eval: func_evals,
424 n_steps: state.steps,
425 n_accepted: state.accepted_steps,
426 n_rejected: state.rejected_steps,
427 n_lu: state.n_lu,
428 n_jac: state.n_jac,
429 method: ODEMethod::LSODA,
430 })
431}
432
433#[allow(dead_code)]
435fn adams_step<F, Func>(
436 state: &mut LsodaState<F>,
437 f: &Func,
438 opts: &ODEOptions<F>,
439 func_evals: &mut usize,
440) -> IntegrateResult<bool>
441where
442 F: IntegrateFloat,
443 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
444{
445 let ab_coeffs: [Vec<F>; 12] = [
448 vec![F::one()],
450 vec![
452 F::from_f64(3.0 / 2.0).unwrap(),
453 F::from_f64(-1.0 / 2.0).unwrap(),
454 ],
455 vec![
457 F::from_f64(23.0 / 12.0).unwrap(),
458 F::from_f64(-16.0 / 12.0).unwrap(),
459 F::from_f64(5.0 / 12.0).unwrap(),
460 ],
461 vec![
463 F::from_f64(55.0 / 24.0).unwrap(),
464 F::from_f64(-59.0 / 24.0).unwrap(),
465 F::from_f64(37.0 / 24.0).unwrap(),
466 F::from_f64(-9.0 / 24.0).unwrap(),
467 ],
468 vec![
470 F::from_f64(1901.0 / 720.0).unwrap(),
471 F::from_f64(-2774.0 / 720.0).unwrap(),
472 F::from_f64(2616.0 / 720.0).unwrap(),
473 F::from_f64(-1274.0 / 720.0).unwrap(),
474 F::from_f64(251.0 / 720.0).unwrap(),
475 ],
476 vec![
478 F::from_f64(4277.0 / 1440.0).unwrap(),
479 F::from_f64(-7923.0 / 1440.0).unwrap(),
480 F::from_f64(9982.0 / 1440.0).unwrap(),
481 F::from_f64(-7298.0 / 1440.0).unwrap(),
482 F::from_f64(2877.0 / 1440.0).unwrap(),
483 F::from_f64(-475.0 / 1440.0).unwrap(),
484 ],
485 vec![
487 F::from_f64(198721.0 / 60480.0).unwrap(),
488 F::from_f64(-447288.0 / 60480.0).unwrap(),
489 F::from_f64(705549.0 / 60480.0).unwrap(),
490 F::from_f64(-688256.0 / 60480.0).unwrap(),
491 F::from_f64(407139.0 / 60480.0).unwrap(),
492 F::from_f64(-134472.0 / 60480.0).unwrap(),
493 F::from_f64(19087.0 / 60480.0).unwrap(),
494 ],
495 vec![
497 F::from_f64(434241.0 / 120960.0).unwrap(),
498 F::from_f64(-1152169.0 / 120960.0).unwrap(),
499 F::from_f64(2183877.0 / 120960.0).unwrap(),
500 F::from_f64(-2664477.0 / 120960.0).unwrap(),
501 F::from_f64(2102243.0 / 120960.0).unwrap(),
502 F::from_f64(-1041723.0 / 120960.0).unwrap(),
503 F::from_f64(295767.0 / 120960.0).unwrap(),
504 F::from_f64(-36799.0 / 120960.0).unwrap(),
505 ],
506 vec![
508 F::from_f64(14097247.0 / 3628800.0).unwrap(),
509 F::from_f64(-43125206.0 / 3628800.0).unwrap(),
510 F::from_f64(95476786.0 / 3628800.0).unwrap(),
511 F::from_f64(-139855262.0 / 3628800.0).unwrap(),
512 F::from_f64(137968480.0 / 3628800.0).unwrap(),
513 F::from_f64(-91172642.0 / 3628800.0).unwrap(),
514 F::from_f64(38833486.0 / 3628800.0).unwrap(),
515 F::from_f64(-9664106.0 / 3628800.0).unwrap(),
516 F::from_f64(1070017.0 / 3628800.0).unwrap(),
517 ],
518 vec![
520 F::from_f64(30277247.0 / 7257600.0).unwrap(),
521 F::from_f64(-104995189.0 / 7257600.0).unwrap(),
522 F::from_f64(265932680.0 / 7257600.0).unwrap(),
523 F::from_f64(-454661776.0 / 7257600.0).unwrap(),
524 F::from_f64(538363838.0 / 7257600.0).unwrap(),
525 F::from_f64(-444772162.0 / 7257600.0).unwrap(),
526 F::from_f64(252618224.0 / 7257600.0).unwrap(),
527 F::from_f64(-94307320.0 / 7257600.0).unwrap(),
528 F::from_f64(20884811.0 / 7257600.0).unwrap(),
529 F::from_f64(-2082753.0 / 7257600.0).unwrap(),
530 ],
531 vec![
533 F::from_f64(35256204767.0 / 7983360000.0).unwrap(),
534 F::from_f64(-134336876800.0 / 7983360000.0).unwrap(),
535 F::from_f64(385146025457.0 / 7983360000.0).unwrap(),
536 F::from_f64(-754734083733.0 / 7983360000.0).unwrap(),
537 F::from_f64(1045594573504.0 / 7983360000.0).unwrap(),
538 F::from_f64(-1029725952608.0 / 7983360000.0).unwrap(),
539 F::from_f64(717313887930.0 / 7983360000.0).unwrap(),
540 F::from_f64(-344156361067.0 / 7983360000.0).unwrap(),
541 F::from_f64(109301088672.0 / 7983360000.0).unwrap(),
542 F::from_f64(-21157613775.0 / 7983360000.0).unwrap(),
543 F::from_f64(1832380165.0 / 7983360000.0).unwrap(),
544 ],
545 vec![
547 F::from_f64(77737505967.0 / 16876492800.0).unwrap(),
548 F::from_f64(-328202700680.0 / 16876492800.0).unwrap(),
549 F::from_f64(1074851727475.0 / 16876492800.0).unwrap(),
550 F::from_f64(-2459572352768.0 / 16876492800.0).unwrap(),
551 F::from_f64(4013465151807.0 / 16876492800.0).unwrap(),
552 F::from_f64(-4774671405984.0 / 16876492800.0).unwrap(),
553 F::from_f64(4127030565077.0 / 16876492800.0).unwrap(),
554 F::from_f64(-2538584431976.0 / 16876492800.0).unwrap(),
555 F::from_f64(1077984741336.0 / 16876492800.0).unwrap(),
556 F::from_f64(-295501032385.0 / 16876492800.0).unwrap(),
557 F::from_f64(48902348238.0 / 16876492800.0).unwrap(),
558 F::from_f64(-3525779602.0 / 16876492800.0).unwrap(),
559 ],
560 ];
561
562 let am_coeffs: [Vec<F>; 12] = [
565 vec![F::one()],
567 vec![
569 F::from_f64(1.0 / 2.0).unwrap(),
570 F::from_f64(1.0 / 2.0).unwrap(),
571 ],
572 vec![
574 F::from_f64(5.0 / 12.0).unwrap(),
575 F::from_f64(8.0 / 12.0).unwrap(),
576 F::from_f64(-1.0 / 12.0).unwrap(),
577 ],
578 vec![
580 F::from_f64(9.0 / 24.0).unwrap(),
581 F::from_f64(19.0 / 24.0).unwrap(),
582 F::from_f64(-5.0 / 24.0).unwrap(),
583 F::from_f64(1.0 / 24.0).unwrap(),
584 ],
585 vec![
587 F::from_f64(251.0 / 720.0).unwrap(),
588 F::from_f64(646.0 / 720.0).unwrap(),
589 F::from_f64(-264.0 / 720.0).unwrap(),
590 F::from_f64(106.0 / 720.0).unwrap(),
591 F::from_f64(-19.0 / 720.0).unwrap(),
592 ],
593 vec![
596 F::from_f64(475.0 / 1440.0).unwrap(),
597 F::from_f64(1427.0 / 1440.0).unwrap(),
598 F::from_f64(-798.0 / 1440.0).unwrap(),
599 F::from_f64(482.0 / 1440.0).unwrap(),
600 F::from_f64(-173.0 / 1440.0).unwrap(),
601 F::from_f64(27.0 / 1440.0).unwrap(),
602 ],
603 vec![
605 F::from_f64(19087.0 / 60480.0).unwrap(),
606 F::from_f64(65112.0 / 60480.0).unwrap(),
607 F::from_f64(-46461.0 / 60480.0).unwrap(),
608 F::from_f64(37504.0 / 60480.0).unwrap(),
609 F::from_f64(-20211.0 / 60480.0).unwrap(),
610 F::from_f64(6312.0 / 60480.0).unwrap(),
611 F::from_f64(-863.0 / 60480.0).unwrap(),
612 ],
613 vec![
615 F::from_f64(36799.0 / 120960.0).unwrap(),
616 F::from_f64(139849.0 / 120960.0).unwrap(),
617 F::from_f64(-121797.0 / 120960.0).unwrap(),
618 F::from_f64(123133.0 / 120960.0).unwrap(),
619 F::from_f64(-88547.0 / 120960.0).unwrap(),
620 F::from_f64(41499.0 / 120960.0).unwrap(),
621 F::from_f64(-11351.0 / 120960.0).unwrap(),
622 F::from_f64(1375.0 / 120960.0).unwrap(),
623 ],
624 vec![
626 F::from_f64(1070017.0 / 3628800.0).unwrap(),
627 F::from_f64(4467094.0 / 3628800.0).unwrap(),
628 F::from_f64(-4604594.0 / 3628800.0).unwrap(),
629 F::from_f64(5595358.0 / 3628800.0).unwrap(),
630 F::from_f64(-5033120.0 / 3628800.0).unwrap(),
631 F::from_f64(3146338.0 / 3628800.0).unwrap(),
632 F::from_f64(-1291214.0 / 3628800.0).unwrap(),
633 F::from_f64(312874.0 / 3628800.0).unwrap(),
634 F::from_f64(-33953.0 / 3628800.0).unwrap(),
635 ],
636 vec![
638 F::from_f64(2082753.0 / 7257600.0).unwrap(),
639 F::from_f64(9449717.0 / 7257600.0).unwrap(),
640 F::from_f64(-11271304.0 / 7257600.0).unwrap(),
641 F::from_f64(16002320.0 / 7257600.0).unwrap(),
642 F::from_f64(-17283646.0 / 7257600.0).unwrap(),
643 F::from_f64(13510082.0 / 7257600.0).unwrap(),
644 F::from_f64(-7394032.0 / 7257600.0).unwrap(),
645 F::from_f64(2687864.0 / 7257600.0).unwrap(),
646 F::from_f64(-583435.0 / 7257600.0).unwrap(),
647 F::from_f64(57281.0 / 7257600.0).unwrap(),
648 ],
649 vec![
651 F::from_f64(1832380165.0 / 7983360000.0).unwrap(),
652 F::from_f64(8862145928.0 / 7983360000.0).unwrap(),
653 F::from_f64(-11901858253.0 / 7983360000.0).unwrap(),
654 F::from_f64(19151811844.0 / 7983360000.0).unwrap(),
655 F::from_f64(-23709112128.0 / 7983360000.0).unwrap(),
656 F::from_f64(22186204517.0 / 7983360000.0).unwrap(),
657 F::from_f64(-15364126130.0 / 7983360000.0).unwrap(),
658 F::from_f64(7503814963.0 / 7983360000.0).unwrap(),
659 F::from_f64(-2395311906.0 / 7983360000.0).unwrap(),
660 F::from_f64(467772723.0 / 7983360000.0).unwrap(),
661 F::from_f64(-41469557.0 / 7983360000.0).unwrap(),
662 ],
663 vec![
665 F::from_f64(3525779602.0 / 16876492800.0).unwrap(),
666 F::from_f64(17870808964.0 / 16876492800.0).unwrap(),
667 F::from_f64(-26564533485.0 / 16876492800.0).unwrap(),
668 F::from_f64(47566383032.0 / 16876492800.0).unwrap(),
669 F::from_f64(-66692205045.0 / 16876492800.0).unwrap(),
670 F::from_f64(72077402760.0 / 16876492800.0).unwrap(),
671 F::from_f64(-59658274307.0 / 16876492800.0).unwrap(),
672 F::from_f64(36174330240.0 / 16876492800.0).unwrap(),
673 F::from_f64(-15568150189.0 / 16876492800.0).unwrap(),
674 F::from_f64(4443502217.0 / 16876492800.0).unwrap(),
675 F::from_f64(-772653805.0 / 16876492800.0).unwrap(),
676 F::from_f64(62628216.0 / 16876492800.0).unwrap(),
677 ],
678 ];
679
680 let order = state.order.min(state.dy_history.len() + 1).min(12);
682
683 if order == 1 || state.dy_history.is_empty() {
685 let next_t = state.t + state.h;
687 let next_y = state.y.clone() + state.dy.clone() * state.h;
688
689 let next_dy = f(next_t, next_y.view());
691 *func_evals += 1;
692 state.func_evals += 1;
693
694 state.t = next_t;
696 state.y = next_y;
697 state.dy = next_dy;
698
699 if state.order < 2 {
701 state.order += 1;
702 }
703
704 return Ok(true);
705 }
706
707 let next_t = state.t + state.h;
709 let ab_coefs = &ab_coeffs[order - 1];
710
711 let mut ab_sum = state.dy.clone() * ab_coefs[0];
714
715 for (i, &coeff) in ab_coefs.iter().enumerate().take(order).skip(1) {
716 if i <= state.dy_history.len() {
717 let idx = state.dy_history.len() - i;
718 ab_sum = ab_sum + state.dy_history[idx].clone() * coeff;
719 }
720 }
721
722 let y_pred = state.y.clone() + ab_sum * state.h;
723
724 let dy_pred = f(next_t, y_pred.view());
726 *func_evals += 1;
727 state.func_evals += 1;
728
729 let am_coefs = &am_coeffs[order - 1];
731
732 let mut am_sum = dy_pred.clone() * am_coefs[0]; for (i, &coeff) in am_coefs.iter().enumerate().take(order).skip(1) {
737 if i == 1 {
738 am_sum = am_sum + state.dy.clone() * coeff;
740 } else if i - 1 < state.dy_history.len() {
741 let idx = state.dy_history.len() - (i - 1);
743 am_sum = am_sum + state.dy_history[idx].clone() * coeff;
744 }
745 }
746
747 let y_corr = state.y.clone() + am_sum * state.h;
748
749 let dy_corr = f(next_t, y_corr.view());
751 *func_evals += 1;
752 state.func_evals += 1;
753
754 let mut max_err = F::zero();
756 for i in 0..state.y.len() {
757 let err = (y_corr[i] - y_pred[i]).abs();
758 let scale = state.tol_scale[i];
759 max_err = max_err.max(err / scale);
760 }
761
762 let err_order = order + 1; let err_factor = if max_err > F::zero() {
765 F::from_f64(0.9).unwrap()
766 * (F::one() / max_err).powf(F::one() / F::from_usize(err_order).unwrap())
767 } else {
768 F::from_f64(5.0).unwrap() };
770
771 let safety = F::from_f64(0.9).unwrap();
773 let factor_max = F::from_f64(5.0).unwrap();
774 let factor_min = F::from_f64(0.2).unwrap();
775 let factor = safety * err_factor.min(factor_max).max(factor_min);
776
777 if max_err <= F::one() {
779 state.t = next_t;
783 state.y = y_corr;
784 state.dy = dy_corr;
785
786 state.h *= factor;
788
789 if order < 12 && max_err < opts.rtol && state.dy_history.len() >= order {
791 state.order = (state.order + 1).min(12);
792 } else if order > 1 && max_err > F::from_f64(0.5).unwrap() {
793 state.order = (state.order - 1).max(1);
794 }
795
796 if max_err < opts.rtol * F::from_f64(0.01).unwrap() {
798 state.non_stiffness_detected_count += 1;
799 }
800
801 Ok(true)
802 } else {
803 state.h *= factor;
807
808 if max_err > F::from_f64(10.0).unwrap() {
810 state.stiffness_detected_count += 1;
811
812 if state.stiffness_detected_count > 2 {
814 return Err(IntegrateError::ConvergenceError(
815 "Problem appears stiff - consider using Bdf method".to_string(),
816 ));
817 }
818 }
819
820 Ok(false)
821 }
822}
823
824#[allow(dead_code)]
826fn bdf_step<F, Func>(
827 state: &mut LsodaState<F>,
828 f: &Func,
829 opts: &ODEOptions<F>,
830 func_evals: &mut usize,
831) -> IntegrateResult<bool>
832where
833 F: IntegrateFloat,
834 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
835{
836 let bdf_coefs: [Vec<F>; 5] = [
838 vec![F::one(), F::from_f64(-1.0).unwrap()],
840 vec![
842 F::from_f64(3.0 / 2.0).unwrap(),
843 F::from_f64(-2.0).unwrap(),
844 F::from_f64(1.0 / 2.0).unwrap(),
845 ],
846 vec![
848 F::from_f64(11.0 / 6.0).unwrap(),
849 F::from_f64(-3.0).unwrap(),
850 F::from_f64(3.0 / 2.0).unwrap(),
851 F::from_f64(-1.0 / 3.0).unwrap(),
852 ],
853 vec![
855 F::from_f64(25.0 / 12.0).unwrap(),
856 F::from_f64(-4.0).unwrap(),
857 F::from_f64(3.0).unwrap(),
858 F::from_f64(-4.0 / 3.0).unwrap(),
859 F::from_f64(1.0 / 4.0).unwrap(),
860 ],
861 vec![
863 F::from_f64(137.0 / 60.0).unwrap(),
864 F::from_f64(-5.0).unwrap(),
865 F::from_f64(5.0).unwrap(),
866 F::from_f64(-10.0 / 3.0).unwrap(),
867 F::from_f64(5.0 / 4.0).unwrap(),
868 F::from_f64(-1.0 / 5.0).unwrap(),
869 ],
870 ];
871
872 let order = state.order.min(state.y_history.len()).min(5);
874
875 if order == 1 || state.y_history.is_empty() {
877 let next_t = state.t + state.h;
879
880 let y_pred = state.y.clone();
882
883 let max_newton_iters = 10;
885 let newton_tol = F::from_f64(1e-8).unwrap();
886 let mut y_next = y_pred.clone();
887 let mut converged = false;
888 let mut iter_count = 0;
889
890 while iter_count < max_newton_iters {
891 let f_eval = f(next_t, y_next.view());
893 *func_evals += 1;
894 state.func_evals += 1;
895
896 let residual = y_next.clone() - state.y.clone() - f_eval.clone() * state.h;
898
899 let mut max_res = F::zero();
901 for i in 0..y_next.len() {
902 let scale = state.tol_scale[i];
903 max_res = max_res.max(residual[i].abs() / scale);
904 }
905
906 if max_res <= newton_tol {
907 converged = true;
908 break;
909 }
910
911 let eps = F::from_f64(1e-8).unwrap();
913 let n_dim = y_next.len();
914 let mut jacobian = Array2::<F>::zeros((n_dim, n_dim));
915
916 let compute_new_jacobian =
918 state.jacobian.is_none() || state.jacobian_age > 20 || iter_count == 0;
919
920 if compute_new_jacobian {
921 state.n_jac += 1;
922
923 for i in 0..n_dim {
924 let mut y_perturbed = y_next.clone();
925 y_perturbed[i] += eps;
926
927 let f_perturbed = f(next_t, y_perturbed.view());
928 *func_evals += 1;
929 state.func_evals += 1;
930
931 for j in 0..n_dim {
932 let df_dy = (f_perturbed[j] - f_eval[j]) / eps;
934
935 jacobian[[j, i]] = if i == j {
937 F::one() - state.h * df_dy
938 } else {
939 -state.h * df_dy
940 };
941 }
942 }
943
944 state.jacobian = Some(jacobian.clone());
946 state.jacobian_age = 0;
947 } else {
948 jacobian = state.jacobian.clone().unwrap();
950 }
951
952 let mut aug = Array2::<F>::zeros((n_dim, n_dim + 1));
954 for i in 0..n_dim {
955 for j in 0..n_dim {
956 aug[[i, j]] = jacobian[[i, j]];
957 }
958 aug[[i, n_dim]] = residual[i];
959 }
960
961 state.n_lu += 1;
962
963 for i in 0..n_dim {
965 let mut max_idx = i;
967 let mut max_val = aug[[i, i]].abs();
968
969 for j in i + 1..n_dim {
970 if aug[[j, i]].abs() > max_val {
971 max_idx = j;
972 max_val = aug[[j, i]].abs();
973 }
974 }
975
976 if max_val < F::from_f64(1e-10).unwrap() {
978 state.h *= F::from_f64(0.5).unwrap();
980 return Ok(false);
981 }
982
983 if max_idx != i {
985 for j in 0..n_dim + 1 {
986 let temp = aug[[i, j]];
987 aug[[i, j]] = aug[[max_idx, j]];
988 aug[[max_idx, j]] = temp;
989 }
990 }
991
992 for j in i + 1..n_dim {
994 let factor = aug[[j, i]] / aug[[i, i]];
995 for k in i..n_dim + 1 {
996 aug[[j, k]] = aug[[j, k]] - factor * aug[[i, k]];
997 }
998 }
999 }
1000
1001 let mut delta_y = Array1::<F>::zeros(n_dim);
1003 for i in (0..n_dim).rev() {
1004 let mut sum = aug[[i, n_dim]];
1005 for j in i + 1..n_dim {
1006 sum -= aug[[i, j]] * delta_y[j];
1007 }
1008 delta_y[i] = sum / aug[[i, i]];
1009 }
1010
1011 for i in 0..n_dim {
1013 y_next[i] -= delta_y[i];
1014 }
1015
1016 iter_count += 1;
1017 }
1018
1019 if !converged {
1020 state.h *= F::from_f64(0.5).unwrap();
1022
1023 if state.h < opts.min_step.unwrap_or(F::from_f64(1e-10).unwrap()) {
1026 state.non_stiffness_detected_count += 1;
1028
1029 return Err(IntegrateError::ConvergenceError(
1030 "BDF1 failed to converge - problem might be non-stiff".to_string(),
1031 ));
1032 }
1033
1034 return Ok(false);
1035 }
1036
1037 let next_dy = f(next_t, y_next.view());
1041 *func_evals += 1;
1042 state.func_evals += 1;
1043
1044 state.t = next_t;
1046 state.y = y_next;
1047 state.dy = next_dy;
1048
1049 if state.order < 2 {
1051 state.order += 1;
1052 }
1053
1054 return Ok(true);
1055 }
1056
1057 let coeffs = &bdf_coefs[order - 1];
1061
1062 let next_t = state.t + state.h;
1064
1065 let mut y_pred = state.y.clone();
1067
1068 if order > 1 && !state.y_history.is_empty() {
1070 let y_prev = &state.y_history[state.y_history.len() - 1];
1071
1072 let dt_ratio = state.h / (state.t - state.t_history[state.t_history.len() - 1]);
1074 y_pred = state.y.clone() + (state.y.clone() - y_prev) * dt_ratio;
1075 }
1076
1077 let max_newton_iters = 10;
1079 let newton_tol = F::from_f64(1e-8).unwrap();
1080 let mut y_next = y_pred.clone();
1081 let mut converged = false;
1082 let mut iter_count = 0;
1083
1084 while iter_count < max_newton_iters {
1085 let f_eval = f(next_t, y_next.view());
1087 *func_evals += 1;
1088 state.func_evals += 1;
1089
1090 let mut residual = y_next.clone() * coeffs[0];
1092
1093 residual = residual - state.y.clone() * coeffs[1];
1095
1096 for (j, &coeff) in coeffs.iter().enumerate().skip(2) {
1097 if j - 1 < state.y_history.len() {
1098 let idx = state.y_history.len() - (j - 1);
1099 residual = residual - state.y_history[idx].clone() * coeff;
1100 }
1101 }
1102
1103 residual = residual - f_eval.clone() * state.h;
1105
1106 let mut max_res = F::zero();
1108 for i in 0..y_next.len() {
1109 let scale = state.tol_scale[i];
1110 max_res = max_res.max(residual[i].abs() / scale);
1111 }
1112
1113 if max_res <= newton_tol {
1114 converged = true;
1115 break;
1116 }
1117
1118 let eps = F::from_f64(1e-8).unwrap();
1120 let n_dim = y_next.len();
1121 let mut jacobian = Array2::<F>::zeros((n_dim, n_dim));
1122
1123 let compute_new_jacobian =
1125 state.jacobian.is_none() || state.jacobian_age > 20 || iter_count == 0;
1126
1127 if compute_new_jacobian {
1128 state.n_jac += 1;
1129
1130 for i in 0..n_dim {
1131 let mut y_perturbed = y_next.clone();
1132 y_perturbed[i] += eps;
1133
1134 let f_perturbed = f(next_t, y_perturbed.view());
1135 *func_evals += 1;
1136 state.func_evals += 1;
1137
1138 for j in 0..n_dim {
1139 let df_dy = (f_perturbed[j] - f_eval[j]) / eps;
1141
1142 jacobian[[j, i]] = if i == j {
1144 coeffs[0] - state.h * df_dy
1145 } else {
1146 -state.h * df_dy
1147 };
1148 }
1149 }
1150
1151 state.jacobian = Some(jacobian.clone());
1153 state.jacobian_age = 0;
1154 } else {
1155 jacobian = state.jacobian.clone().unwrap();
1157 }
1158
1159 state.n_lu += 1;
1161
1162 let mut aug = Array2::<F>::zeros((n_dim, n_dim + 1));
1163 for i in 0..n_dim {
1164 for j in 0..n_dim {
1165 aug[[i, j]] = jacobian[[i, j]];
1166 }
1167 aug[[i, n_dim]] = residual[i];
1168 }
1169
1170 for i in 0..n_dim {
1172 let mut max_idx = i;
1174 let mut max_val = aug[[i, i]].abs();
1175
1176 for j in i + 1..n_dim {
1177 if aug[[j, i]].abs() > max_val {
1178 max_idx = j;
1179 max_val = aug[[j, i]].abs();
1180 }
1181 }
1182
1183 if max_val < F::from_f64(1e-10).unwrap() {
1185 state.h *= F::from_f64(0.5).unwrap();
1187 return Ok(false);
1188 }
1189
1190 if max_idx != i {
1192 for j in 0..n_dim + 1 {
1193 let temp = aug[[i, j]];
1194 aug[[i, j]] = aug[[max_idx, j]];
1195 aug[[max_idx, j]] = temp;
1196 }
1197 }
1198
1199 for j in i + 1..n_dim {
1201 let factor = aug[[j, i]] / aug[[i, i]];
1202 for k in i..n_dim + 1 {
1203 aug[[j, k]] = aug[[j, k]] - factor * aug[[i, k]];
1204 }
1205 }
1206 }
1207
1208 let mut delta_y = Array1::<F>::zeros(n_dim);
1210 for i in (0..n_dim).rev() {
1211 let mut sum = aug[[i, n_dim]];
1212 for j in i + 1..n_dim {
1213 sum -= aug[[i, j]] * delta_y[j];
1214 }
1215 delta_y[i] = sum / aug[[i, i]];
1216 }
1217
1218 for i in 0..n_dim {
1220 y_next[i] -= delta_y[i];
1221 }
1222
1223 iter_count += 1;
1224 }
1225
1226 if !converged {
1227 state.h *= F::from_f64(0.5).unwrap();
1229
1230 if iter_count >= max_newton_iters - 1 {
1232 state.non_stiffness_detected_count += 1;
1233 }
1234
1235 if state.h < opts.min_step.unwrap_or(F::from_f64(1e-10).unwrap()) {
1237 return Err(IntegrateError::ConvergenceError(
1238 "Bdf failed to converge - problem might be non-stiff".to_string(),
1239 ));
1240 }
1241
1242 return Ok(false);
1243 }
1244
1245 let next_dy = f(next_t, y_next.view());
1249 *func_evals += 1;
1250 state.func_evals += 1;
1251
1252 let lower_order = (order - 1).max(1);
1257 let lower_coeffs = &bdf_coefs[lower_order - 1];
1258
1259 let mut y_lower = Array1::<F>::zeros(y_next.len());
1261
1262 let mut rhs = Array1::<F>::zeros(y_next.len());
1264 for (j, &coeff) in lower_coeffs.iter().enumerate().skip(1).take(lower_order) {
1265 if j == 1 {
1266 rhs = rhs + state.y.clone() * coeff;
1267 } else if j - 1 < state.y_history.len() {
1268 let idx = state.y_history.len() - (j - 1);
1269 rhs = rhs + state.y_history[idx].clone() * coeff;
1270 }
1271 }
1272
1273 rhs = rhs + next_dy.clone() * state.h;
1275
1276 for i in 0..y_next.len() {
1278 y_lower[i] = rhs[i] / lower_coeffs[0];
1279 }
1280
1281 let mut max_err = F::zero();
1283 for i in 0..y_next.len() {
1284 let local_err = (y_next[i] - y_lower[i]).abs();
1285 let scale = state.tol_scale[i];
1286 max_err = max_err.max(local_err / scale);
1287 }
1288
1289 let err_order = order + 1; let err_factor = if max_err > F::zero() {
1292 F::from_f64(0.9).unwrap()
1293 * (F::one() / max_err).powf(F::one() / F::from_usize(err_order).unwrap())
1294 } else {
1295 F::from_f64(5.0).unwrap() };
1297
1298 let safety = F::from_f64(0.9).unwrap();
1300 let factor_max = F::from_f64(5.0).unwrap();
1301 let factor_min = F::from_f64(0.2).unwrap();
1302 let factor = safety * err_factor.min(factor_max).max(factor_min);
1303
1304 state.t = next_t;
1306 state.y = y_next;
1307 state.dy = next_dy;
1308
1309 state.h *= factor;
1311
1312 if order < 5 && max_err < opts.rtol && state.y_history.len() >= order {
1315 state.order = (state.order + 1).min(5);
1316 }
1317 else if order > 1 && (max_err > F::from_f64(0.5).unwrap() || iter_count > 2) {
1319 state.order = (state.order - 1).max(1);
1320 }
1321
1322 if iter_count <= 2 {
1324 state.non_stiffness_detected_count += 1;
1325 }
1326
1327 state.jacobian_age += 1;
1329
1330 Ok(true)
1331}