1use crate::error::{IntegrateError, IntegrateResult};
30use crate::IntegrateFloat;
31use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
32
33#[inline(always)]
39fn to_f<F: IntegrateFloat>(v: f64) -> F {
40 F::from_f64(v).unwrap_or_else(F::zero)
41}
42
/// Right-hand side `f(t, y) = f_E(t, y) + f_I(t, y)` split into a part that the IMEX
/// integrators in this module advance explicitly (`explicit_part`) and a part they treat
/// implicitly with Newton's method (`implicit_part` + `jacobian_implicit`).
///
/// Implementations must be `Send + Sync`.
pub trait SplitFunction<F: IntegrateFloat>: Send + Sync {
    /// Explicitly treated part `f_E(t, y)`.
    ///
    /// NOTE(review): the integrators index the result with `0..dimension()`, so it is
    /// expected to have length `dimension()`.
    fn explicit_part(&self, t: F, y: ArrayView1<F>) -> Array1<F>;

    /// Implicitly treated part `f_I(t, y)`.
    fn implicit_part(&self, t: F, y: ArrayView1<F>) -> Array1<F>;

    /// Jacobian of `implicit_part` with respect to `y`, evaluated at `(t, y)`.
    ///
    /// The Newton solvers build the matrix `alpha*I - dt*J` from it, so it is expected to be
    /// a square `dimension() x dimension()` matrix.
    fn jacobian_implicit(&self, t: F, y: ArrayView1<F>) -> Array2<F>;

    /// Number of state components; initial conditions are checked against it.
    fn dimension(&self) -> usize;
}
69
/// Settings shared by the fixed-step IMEX integrators in this module.
#[derive(Debug, Clone)]
pub struct IMEXConfig<F: IntegrateFloat> {
    /// Nominal step size; only the final step may differ so the run ends at `t_end`.
    pub dt: F,
    /// Final integration time.
    pub t_end: F,
    /// Relative tolerance.
    ///
    /// NOTE(review): not read by any integrator in this file — confirm whether it is used
    /// elsewhere or reserved for future adaptive stepping.
    pub rtol: F,
    /// Absolute tolerance.
    ///
    /// NOTE(review): like `rtol`, not read by any integrator in this file.
    pub atol: F,
    /// Maximum number of Newton iterations per implicit solve.
    pub max_iter_newton: usize,
    /// Newton iterations stop once the Euclidean norm of the residual is below this value.
    pub newton_tol: F,
    /// When `true`, the integrators record stiffness-ratio estimates in
    /// `IMEXResult::stiffness_ratio`.
    pub compute_stiffness: bool,
}
92
93impl Default for IMEXConfig<f64> {
94 fn default() -> Self {
95 Self {
96 dt: 1e-3,
97 t_end: 1.0,
98 rtol: 1e-6,
99 atol: 1e-9,
100 max_iter_newton: 50,
101 newton_tol: 1e-10,
102 compute_stiffness: false,
103 }
104 }
105}
106
107impl<F: IntegrateFloat> IMEXConfig<F> {
108 pub fn new(dt: F, t_end: F) -> Self {
110 Self {
111 dt,
112 t_end,
113 rtol: to_f(1e-6),
114 atol: to_f(1e-9),
115 max_iter_newton: 50,
116 newton_tol: to_f(1e-10),
117 compute_stiffness: false,
118 }
119 }
120}
121
/// Output of an IMEX integration run.
#[derive(Debug, Clone)]
pub struct IMEXResult<F: IntegrateFloat> {
    /// Time points, beginning with the initial time.
    pub t: Vec<F>,
    /// Solution states; `y[k]` is the state stored alongside `t[k]`.
    pub y: Vec<Array1<F>>,
    /// Stiffness-ratio estimates, filled only when `IMEXConfig::compute_stiffness` is set.
    pub stiffness_ratio: Vec<F>,
    /// Number of time steps taken.
    pub n_steps: usize,
    /// Newton iterations summed over all implicit solves of the run.
    pub n_newton_iters: usize,
}
141
142fn gaussian_elimination<F: IntegrateFloat>(
150 a: &mut Array2<F>,
151 b: &mut Array1<F>,
152) -> IntegrateResult<Array1<F>> {
153 let n = b.len();
154 if a.shape() != [n, n] {
155 return Err(IntegrateError::DimensionMismatch(format!(
156 "Matrix shape {:?} incompatible with RHS length {}",
157 a.shape(),
158 n
159 )));
160 }
161
162 for col in 0..n {
164 let mut max_row = col;
166 let mut max_val = a[[col, col]].abs();
167 for row in (col + 1)..n {
168 let v = a[[row, col]].abs();
169 if v > max_val {
170 max_val = v;
171 max_row = row;
172 }
173 }
174
175 if max_val < to_f(1e-300) {
176 return Err(IntegrateError::LinearSolveError(
177 "Singular or near-singular matrix in IMEX Newton solve".to_string(),
178 ));
179 }
180
181 if max_row != col {
183 for j in col..n {
184 let tmp = a[[col, j]];
185 a[[col, j]] = a[[max_row, j]];
186 a[[max_row, j]] = tmp;
187 }
188 b.swap(col, max_row);
189 }
190
191 let pivot = a[[col, col]];
193 for row in (col + 1)..n {
194 let factor = a[[row, col]] / pivot;
195 for j in col..n {
196 let update = factor * a[[col, j]];
197 a[[row, j]] -= update;
198 }
199 let bupdate = factor * b[col];
200 b[row] -= bupdate;
201 }
202 }
203
204 let mut x = Array1::<F>::zeros(n);
206 for i in (0..n).rev() {
207 let mut sum = b[i];
208 for j in (i + 1)..n {
209 let ax = a[[i, j]] * x[j];
210 sum -= ax;
211 }
212 x[i] = sum / a[[i, i]];
213 }
214
215 Ok(x)
216}
217
218fn solve_imex_linear<F: IntegrateFloat>(
222 jac: &Array2<F>,
223 rhs: &Array1<F>,
224 alpha: F,
225 dt: F,
226) -> IntegrateResult<Array1<F>> {
227 let n = rhs.len();
228 let mut mat = Array2::<F>::zeros((n, n));
229 for i in 0..n {
231 for j in 0..n {
232 mat[[i, j]] = if i == j {
233 alpha - dt * jac[[i, j]]
234 } else {
235 F::zero() - dt * jac[[i, j]]
236 };
237 }
238 }
239 let mut rhs_copy = rhs.clone();
240 gaussian_elimination(&mut mat, &mut rhs_copy)
241}
242
243fn newton_solve_implicit<F, Sys>(
252 sys: &Sys,
253 t: F,
254 y_prev: &Array1<F>,
255 explicit_term: &Array1<F>,
256 dt: F,
257 cfg: &IMEXConfig<F>,
258) -> IntegrateResult<(Array1<F>, usize)>
259where
260 F: IntegrateFloat,
261 Sys: SplitFunction<F>,
262{
263 let n = y_prev.len();
264 let mut y = y_prev.clone();
265 let mut n_iters;
266
267 for iter in 0..cfg.max_iter_newton {
268 let f_i = sys.implicit_part(t, y.view());
269 let mut residual = Array1::<F>::zeros(n);
271 for i in 0..n {
272 residual[i] = y[i] - y_prev[i] - dt * f_i[i] - explicit_term[i];
273 }
274
275 let res_norm = residual
277 .iter()
278 .fold(F::zero(), |acc, &r| acc + r * r)
279 .sqrt();
280 if res_norm < cfg.newton_tol {
281 n_iters = iter + 1;
282 return Ok((y, n_iters));
283 }
284
285 let jac = sys.jacobian_implicit(t, y.view());
287 let neg_res: Array1<F> = residual.mapv(|r| F::zero() - r);
289 let delta = solve_imex_linear(&jac, &neg_res, F::one(), dt)?;
290
291 for i in 0..n {
293 y[i] += delta[i];
294 }
295 }
296
297 n_iters = cfg.max_iter_newton;
299 Err(IntegrateError::ConvergenceError(format!(
300 "IMEX Newton solver did not converge in {} iterations",
301 cfg.max_iter_newton
302 )))
303 .or(Ok((y, n_iters)))
304}
305
306pub fn imex_euler<F, Sys>(
329 sys: &Sys,
330 t0: F,
331 y0: Array1<F>,
332 cfg: &IMEXConfig<F>,
333) -> IntegrateResult<IMEXResult<F>>
334where
335 F: IntegrateFloat,
336 Sys: SplitFunction<F>,
337{
338 let n = sys.dimension();
339 if y0.len() != n {
340 return Err(IntegrateError::DimensionMismatch(format!(
341 "Initial condition length {} != system dimension {}",
342 y0.len(),
343 n
344 )));
345 }
346
347 let dt = cfg.dt;
348 let mut t = t0;
349 let mut y = y0.clone();
350
351 let mut ts = vec![t];
352 let mut ys = vec![y0];
353 let mut stiff_ratios: Vec<F> = Vec::new();
354 let mut n_steps = 0usize;
355 let mut total_newton = 0usize;
356
357 while t < cfg.t_end - dt * to_f(0.5) {
358 let step = if t + dt > cfg.t_end {
360 cfg.t_end - t
361 } else {
362 dt
363 };
364 let t_next = t + step;
365
366 let f_e = sys.explicit_part(t, y.view());
368 let mut y_star = Array1::<F>::zeros(n);
369 for i in 0..n {
370 y_star[i] = y[i] + step * f_e[i];
371 }
372
373 let zero_expl = Array1::<F>::zeros(n);
377 match newton_solve_implicit(sys, t_next, &y_star, &zero_expl, step, cfg) {
378 Ok((y_new, iters)) => {
379 total_newton += iters;
380 y = y_new.clone();
381 t = t_next;
382 ts.push(t);
383 ys.push(y_new);
384 n_steps += 1;
385
386 if cfg.compute_stiffness {
387 stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
388 }
389 }
390 Err(e) => return Err(e),
391 }
392 }
393
394 Ok(IMEXResult {
395 t: ts,
396 y: ys,
397 stiffness_ratio: stiff_ratios,
398 n_steps,
399 n_newton_iters: total_newton,
400 })
401}
402
403pub fn imex_midpoint<F, Sys>(
424 sys: &Sys,
425 t0: F,
426 y0: Array1<F>,
427 cfg: &IMEXConfig<F>,
428) -> IntegrateResult<IMEXResult<F>>
429where
430 F: IntegrateFloat,
431 Sys: SplitFunction<F>,
432{
433 let n = sys.dimension();
434 if y0.len() != n {
435 return Err(IntegrateError::DimensionMismatch(format!(
436 "Initial condition length {} != system dimension {}",
437 y0.len(),
438 n
439 )));
440 }
441
442 let dt = cfg.dt;
443 let mut t = t0;
444 let mut y = y0.clone();
445
446 let mut ts = vec![t];
447 let mut ys = vec![y0];
448 let mut stiff_ratios: Vec<F> = Vec::new();
449 let mut n_steps = 0usize;
450 let mut total_newton = 0usize;
451
452 while t < cfg.t_end - dt * to_f(0.5) {
453 let step = if t + dt > cfg.t_end {
454 cfg.t_end - t
455 } else {
456 dt
457 };
458 let t_mid = t + step * to_f(0.5);
459
460 let f_e = sys.explicit_part(t, y.view());
462
463 let mut expl_term = Array1::<F>::zeros(n);
465 for i in 0..n {
466 expl_term[i] = step * f_e[i];
467 }
468
469 let y_n = y.clone();
473 let mut u = y_n.clone();
474 for i in 0..n {
476 u[i] += expl_term[i];
477 }
478
479 let mut n_iters_step = 0usize;
480 let mut converged = false;
481 for _iter in 0..cfg.max_iter_newton {
482 let mut y_mid = Array1::<F>::zeros(n);
484 for i in 0..n {
485 y_mid[i] = (y_n[i] + u[i]) * to_f(0.5);
486 }
487
488 let f_i_mid = sys.implicit_part(t_mid, y_mid.view());
489
490 let mut res = Array1::<F>::zeros(n);
492 for i in 0..n {
493 res[i] = u[i] - y_n[i] - step * f_i_mid[i] - expl_term[i];
494 }
495
496 let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
497 if res_norm < cfg.newton_tol {
498 n_iters_step = _iter + 1;
499 converged = true;
500 break;
501 }
502
503 let jac = sys.jacobian_implicit(t_mid, y_mid.view());
505 let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
506
507 let mut mat = Array2::<F>::zeros((n, n));
509 for i in 0..n {
510 for j in 0..n {
511 mat[[i, j]] = if i == j {
512 F::one() - step * to_f(0.5) * jac[[i, j]]
513 } else {
514 F::zero() - step * to_f(0.5) * jac[[i, j]]
515 };
516 }
517 }
518 let mut rhs_copy = neg_res;
519 let delta = gaussian_elimination(&mut mat, &mut rhs_copy)?;
520
521 for i in 0..n {
522 u[i] += delta[i];
523 }
524 }
525
526 if !converged {
527 n_iters_step = cfg.max_iter_newton;
528 }
529
530 total_newton += n_iters_step;
531 y = u.clone();
532 t += step;
533 ts.push(t);
534 ys.push(u);
535 n_steps += 1;
536
537 if cfg.compute_stiffness {
538 stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
539 }
540 }
541
542 Ok(IMEXResult {
543 t: ts,
544 y: ys,
545 stiffness_ratio: stiff_ratios,
546 n_steps,
547 n_newton_iters: total_newton,
548 })
549}
550
551pub fn imex_bdf2<F, Sys>(
571 sys: &Sys,
572 t0: F,
573 y0: Array1<F>,
574 cfg: &IMEXConfig<F>,
575) -> IntegrateResult<IMEXResult<F>>
576where
577 F: IntegrateFloat,
578 Sys: SplitFunction<F>,
579{
580 let n = sys.dimension();
581 if y0.len() != n {
582 return Err(IntegrateError::DimensionMismatch(format!(
583 "Initial condition length {} != system dimension {}",
584 y0.len(),
585 n
586 )));
587 }
588
589 let dt = cfg.dt;
590
591 let f_e0 = sys.explicit_part(t0, y0.view());
593 let mut y_star = Array1::<F>::zeros(n);
594 for i in 0..n {
595 y_star[i] = y0[i] + dt * f_e0[i];
596 }
597 let zero_expl = Array1::<F>::zeros(n);
598 let (y1, newton0) = newton_solve_implicit(sys, t0 + dt, &y_star, &zero_expl, dt, cfg)
599 .unwrap_or_else(|_| (y_star.clone(), cfg.max_iter_newton));
600
601 let t1 = t0 + dt;
602
603 let mut ts = vec![t0, t1];
604 let mut ys = vec![y0.clone(), y1.clone()];
605 let mut stiff_ratios: Vec<F> = Vec::new();
606 let mut n_steps = 1usize;
607 let mut total_newton = newton0;
608
609 let mut y_prev = y0.clone();
610 let mut f_e_prev = f_e0;
611 let mut y_curr = y1;
612 let mut t_curr = t1;
613
614 while t_curr < cfg.t_end - dt * to_f(0.5) {
616 let step = if t_curr + dt > cfg.t_end {
617 cfg.t_end - t_curr
618 } else {
619 dt
620 };
621 let t_next = t_curr + step;
622
623 let f_e_curr = sys.explicit_part(t_curr, y_curr.view());
624
625 let mut expl_rhs = Array1::<F>::zeros(n);
628 for i in 0..n {
629 expl_rhs[i] = step * (to_f::<F>(2.0) * f_e_curr[i] - f_e_prev[i]);
630 }
631
632 let mut rhs_const = Array1::<F>::zeros(n);
635 for i in 0..n {
636 rhs_const[i] = to_f::<F>(2.0) * y_curr[i] - to_f::<F>(0.5) * y_prev[i] + expl_rhs[i];
637 }
638
639 let mut u = y_curr.clone();
643 let mut n_iters_step = 0usize;
644 let three_half = to_f::<F>(1.5);
645
646 for _iter in 0..cfg.max_iter_newton {
647 let f_i = sys.implicit_part(t_next, u.view());
648 let mut res = Array1::<F>::zeros(n);
649 for i in 0..n {
650 res[i] = three_half * u[i] - step * f_i[i] - rhs_const[i];
651 }
652
653 let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
654 if res_norm < cfg.newton_tol {
655 n_iters_step = _iter + 1;
656 break;
657 }
658
659 let jac = sys.jacobian_implicit(t_next, u.view());
660 let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
661 let delta = solve_imex_linear(&jac, &neg_res, three_half, step)?;
663
664 for i in 0..n {
665 u[i] += delta[i];
666 }
667
668 if _iter + 1 == cfg.max_iter_newton {
669 n_iters_step = cfg.max_iter_newton;
670 }
671 }
672
673 total_newton += n_iters_step;
674
675 y_prev = y_curr;
677 f_e_prev = f_e_curr;
678 y_curr = u.clone();
679 t_curr = t_next;
680
681 ts.push(t_curr);
682 ys.push(u);
683 n_steps += 1;
684
685 if cfg.compute_stiffness {
686 stiff_ratios.push(estimate_stiffness_ratio(sys, t_curr, &y_curr, step)?);
687 }
688 }
689
690 Ok(IMEXResult {
691 t: ts,
692 y: ys,
693 stiffness_ratio: stiff_ratios,
694 n_steps,
695 n_newton_iters: total_newton,
696 })
697}
698
699pub fn imex_ark_ssp2<F, Sys>(
731 sys: &Sys,
732 t0: F,
733 y0: Array1<F>,
734 cfg: &IMEXConfig<F>,
735) -> IntegrateResult<IMEXResult<F>>
736where
737 F: IntegrateFloat,
738 Sys: SplitFunction<F>,
739{
740 let n = sys.dimension();
741 if y0.len() != n {
742 return Err(IntegrateError::DimensionMismatch(format!(
743 "Initial condition length {} != system dimension {}",
744 y0.len(),
745 n
746 )));
747 }
748
749 let gamma: F = to_f(1.0 - 1.0 / std::f64::consts::SQRT_2);
751 let one_minus_gamma: F = F::one() - gamma;
752
753 let dt = cfg.dt;
754 let mut t = t0;
755 let mut y = y0.clone();
756
757 let mut ts = vec![t];
758 let mut ys = vec![y0];
759 let mut stiff_ratios: Vec<F> = Vec::new();
760 let mut n_steps = 0usize;
761 let mut total_newton = 0usize;
762
763 while t < cfg.t_end - dt * to_f(0.5) {
764 let step = if t + dt > cfg.t_end {
765 cfg.t_end - t
766 } else {
767 dt
768 };
769
770 let t_stage1 = t + gamma * step;
776
777 let mut y1_i = y.clone();
779 let mut n_iter1 = 0usize;
780 for _it in 0..cfg.max_iter_newton {
781 let f_i1 = sys.implicit_part(t_stage1, y1_i.view());
782 let mut res = Array1::<F>::zeros(n);
783 for i in 0..n {
784 res[i] = y1_i[i] - step * gamma * f_i1[i] - y[i];
785 }
786 let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
787 if res_norm < cfg.newton_tol {
788 n_iter1 = _it + 1;
789 break;
790 }
791 let jac = sys.jacobian_implicit(t_stage1, y1_i.view());
792 let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
793 let delta = solve_imex_linear(&jac, &neg_res, F::one(), step * gamma)?;
794 for i in 0..n {
795 y1_i[i] += delta[i];
796 }
797 if _it + 1 == cfg.max_iter_newton {
798 n_iter1 = cfg.max_iter_newton;
799 }
800 }
801 total_newton += n_iter1;
802
803 let k1_e = sys.explicit_part(t, y.view());
805 let k1_i = sys.implicit_part(t_stage1, y1_i.view());
807
808 let t_stage2 = t + step; let mut y2_e = Array1::<F>::zeros(n);
815 for i in 0..n {
816 y2_e[i] = y[i] + step * k1_e[i];
817 }
818
819 let mut y2_i = y.clone();
821 for i in 0..n {
823 y2_i[i] = y[i] + step * one_minus_gamma * k1_i[i];
824 }
825
826 let mut n_iter2 = 0usize;
827 for _it in 0..cfg.max_iter_newton {
828 let f_i2 = sys.implicit_part(t_stage2, y2_i.view());
829 let mut res = Array1::<F>::zeros(n);
830 for i in 0..n {
831 res[i] = y2_i[i] - step * one_minus_gamma * k1_i[i] - step * gamma * f_i2[i] - y[i];
832 }
833 let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
834 if res_norm < cfg.newton_tol {
835 n_iter2 = _it + 1;
836 break;
837 }
838 let jac = sys.jacobian_implicit(t_stage2, y2_i.view());
839 let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
840 let delta = solve_imex_linear(&jac, &neg_res, F::one(), step * gamma)?;
841 for i in 0..n {
842 y2_i[i] += delta[i];
843 }
844 if _it + 1 == cfg.max_iter_newton {
845 n_iter2 = cfg.max_iter_newton;
846 }
847 }
848 total_newton += n_iter2;
849
850 let k2_e = sys.explicit_part(t + step, y2_e.view()); let k2_i = sys.implicit_part(t_stage2, y2_i.view());
852
853 let mut y_new = Array1::<F>::zeros(n);
856 for i in 0..n {
857 y_new[i] = y[i]
858 + step * to_f(0.5) * (k1_e[i] + k2_e[i])
859 + step * to_f(0.5) * (k1_i[i] + k2_i[i]);
860 }
861
862 y = y_new.clone();
863 t += step;
864 ts.push(t);
865 ys.push(y_new);
866 n_steps += 1;
867
868 if cfg.compute_stiffness {
869 stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
870 }
871 }
872
873 Ok(IMEXResult {
874 t: ts,
875 y: ys,
876 stiffness_ratio: stiff_ratios,
877 n_steps,
878 n_newton_iters: total_newton,
879 })
880}
881
882pub fn imex_ark_ssp3<F, Sys>(
909 sys: &Sys,
910 t0: F,
911 y0: Array1<F>,
912 cfg: &IMEXConfig<F>,
913) -> IntegrateResult<IMEXResult<F>>
914where
915 F: IntegrateFloat,
916 Sys: SplitFunction<F>,
917{
918 let n = sys.dimension();
919 if y0.len() != n {
920 return Err(IntegrateError::DimensionMismatch(format!(
921 "Initial condition length {} != system dimension {}",
922 y0.len(),
923 n
924 )));
925 }
926
927 let gamma: F = to_f((3.0 + 3.0_f64.sqrt()) / 6.0);
929 let two_gamma = gamma * to_f(2.0);
930 let one_minus_two_gamma = F::one() - two_gamma;
931 let half_minus_gamma: F = to_f::<F>(0.5) - gamma;
932
933 let dt = cfg.dt;
934 let mut t = t0;
935 let mut y = y0.clone();
936
937 let mut ts = vec![t];
938 let mut ys = vec![y0];
939 let mut stiff_ratios: Vec<F> = Vec::new();
940 let mut n_steps = 0usize;
941 let mut total_newton = 0usize;
942
943 while t < cfg.t_end - dt * to_f(0.5) {
944 let step = if t + dt > cfg.t_end {
945 cfg.t_end - t
946 } else {
947 dt
948 };
949
950 let t_i1 = t + gamma * step;
952 let k1_e = sys.explicit_part(t, y.view());
953
954 let (y1_i, ni1) =
956 solve_sdirk_stage(sys, t_i1, &y, &Array1::<F>::zeros(n), gamma, step, cfg)?;
957 total_newton += ni1;
958 let k1_i = sys.implicit_part(t_i1, y1_i.view());
959
960 let t_i2 = t + (F::one() - gamma) * step;
962 let mut y2_e = Array1::<F>::zeros(n);
964 for i in 0..n {
965 y2_e[i] = y[i] + step * k1_e[i];
966 }
967 let k2_e = sys.explicit_part(t + step, y2_e.view());
968
969 let mut acc2 = Array1::<F>::zeros(n);
971 for i in 0..n {
972 acc2[i] = step * one_minus_two_gamma * k1_i[i];
973 }
974 let (y2_i, ni2) = solve_sdirk_stage(sys, t_i2, &y, &acc2, gamma, step, cfg)?;
975 total_newton += ni2;
976 let k2_i = sys.implicit_part(t_i2, y2_i.view());
977
978 let t_i3 = t + to_f::<F>(0.5) * step;
980 let mut y3_e = Array1::<F>::zeros(n);
982 for i in 0..n {
983 y3_e[i] = y[i] + step * (to_f::<F>(0.25) * k1_e[i] + to_f::<F>(0.25) * k2_e[i]);
984 }
985 let k3_e = sys.explicit_part(t + to_f::<F>(0.5) * step, y3_e.view());
986
987 let mut acc3 = Array1::<F>::zeros(n);
989 for i in 0..n {
990 acc3[i] = step * half_minus_gamma * k1_i[i];
991 }
992 let (y3_i, ni3) = solve_sdirk_stage(sys, t_i3, &y, &acc3, gamma, step, cfg)?;
993 total_newton += ni3;
994 let k3_i = sys.implicit_part(t_i3, y3_i.view());
995
996 let mut y_new = Array1::<F>::zeros(n);
999 for i in 0..n {
1000 y_new[i] = y[i]
1001 + step
1002 * (to_f::<F>(1.0 / 6.0) * (k1_e[i] + k1_i[i])
1003 + to_f::<F>(1.0 / 6.0) * (k2_e[i] + k2_i[i])
1004 + to_f::<F>(2.0 / 3.0) * (k3_e[i] + k3_i[i]));
1005 }
1006
1007 y = y_new.clone();
1008 t += step;
1009 ts.push(t);
1010 ys.push(y_new);
1011 n_steps += 1;
1012
1013 if cfg.compute_stiffness {
1014 stiff_ratios.push(estimate_stiffness_ratio(sys, t, &y, step)?);
1015 }
1016 }
1017
1018 Ok(IMEXResult {
1019 t: ts,
1020 y: ys,
1021 stiffness_ratio: stiff_ratios,
1022 n_steps,
1023 n_newton_iters: total_newton,
1024 })
1025}
1026
1027fn solve_sdirk_stage<F, Sys>(
1035 sys: &Sys,
1036 t_stage: F,
1037 y_base: &Array1<F>,
1038 acc: &Array1<F>,
1039 gamma: F,
1040 step: F,
1041 cfg: &IMEXConfig<F>,
1042) -> IntegrateResult<(Array1<F>, usize)>
1043where
1044 F: IntegrateFloat,
1045 Sys: SplitFunction<F>,
1046{
1047 let n = y_base.len();
1048 let mut y = Array1::<F>::zeros(n);
1049 for i in 0..n {
1050 y[i] = y_base[i] + acc[i]; }
1052
1053 let alpha = step * gamma;
1054 let mut n_iters = 0usize;
1055
1056 for _it in 0..cfg.max_iter_newton {
1057 let f_i = sys.implicit_part(t_stage, y.view());
1058 let mut res = Array1::<F>::zeros(n);
1059 for i in 0..n {
1060 res[i] = y[i] - acc[i] - alpha * f_i[i] - y_base[i];
1061 }
1062 let res_norm = res.iter().fold(F::zero(), |acc, &r| acc + r * r).sqrt();
1063 if res_norm < cfg.newton_tol {
1064 n_iters = _it + 1;
1065 return Ok((y, n_iters));
1066 }
1067 let jac = sys.jacobian_implicit(t_stage, y.view());
1068 let neg_res: Array1<F> = res.mapv(|r| F::zero() - r);
1069 let delta = solve_imex_linear(&jac, &neg_res, F::one(), alpha)?;
1070 for i in 0..n {
1071 y[i] += delta[i];
1072 }
1073 if _it + 1 == cfg.max_iter_newton {
1074 n_iters = cfg.max_iter_newton;
1075 }
1076 }
1077
1078 Ok((y, n_iters))
1079}
1080
1081fn estimate_stiffness_ratio<F, Sys>(sys: &Sys, t: F, y: &Array1<F>, _dt: F) -> IntegrateResult<F>
1088where
1089 F: IntegrateFloat,
1090 Sys: SplitFunction<F>,
1091{
1092 let n = sys.dimension();
1093 let j_i = sys.jacobian_implicit(t, y.view());
1094
1095 let mut rho_i = F::zero();
1097 for row in 0..n {
1098 let diag = j_i[[row, row]].abs();
1099 let off_sum: F = (0..n)
1100 .filter(|&j| j != row)
1101 .fold(F::zero(), |s, j| s + j_i[[row, j]].abs());
1102 let r = diag + off_sum;
1103 if r > rho_i {
1104 rho_i = r;
1105 }
1106 }
1107
1108 let eps: F = to_f(1e-7);
1110 let f_base = sys.explicit_part(t, y.view());
1111 let mut rho_e = F::zero();
1112 for col in 0..n {
1113 let mut y_pert = y.clone();
1114 y_pert[col] += eps;
1115 let f_pert = sys.explicit_part(t, y_pert.view());
1116 let col_norm = (0..n)
1117 .fold(F::zero(), |s, row| {
1118 let diff = (f_pert[row] - f_base[row]) / eps;
1119 s + diff * diff
1120 })
1121 .sqrt();
1122 if col_norm > rho_e {
1123 rho_e = col_norm;
1124 }
1125 }
1126
1127 if rho_e < to_f(1e-300) {
1128 Ok(to_f(1.0))
1129 } else {
1130 Ok(rho_i / rho_e)
1131 }
1132}
1133
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Scalar linear model `y' = lambda_nonstiff * y + lambda_stiff * y`. The first term is
    /// handled explicitly and the second implicitly. The exact solution is
    /// `y(t) = y(0) * exp((lambda_nonstiff + lambda_stiff) * t)`.
    struct StiffLinear {
        lambda_stiff: f64,
        lambda_nonstiff: f64,
    }

    impl SplitFunction<f64> for StiffLinear {
        fn explicit_part(&self, _t: f64, y: ArrayView1<f64>) -> Array1<f64> {
            array![self.lambda_nonstiff * y[0]]
        }

        fn implicit_part(&self, _t: f64, y: ArrayView1<f64>) -> Array1<f64> {
            array![self.lambda_stiff * y[0]]
        }

        fn jacobian_implicit(&self, _t: f64, _y: ArrayView1<f64>) -> Array2<f64> {
            Array2::from_elem((1, 1), self.lambda_stiff)
        }

        fn dimension(&self) -> usize {
            1
        }
    }

    /// Fixed-step configuration shared by every decay test: `dt = 0.01`, tight Newton tolerance.
    fn decay_config(t_end: f64) -> IMEXConfig<f64> {
        IMEXConfig {
            dt: 0.01,
            t_end,
            newton_tol: 1e-12,
            ..IMEXConfig::default()
        }
    }

    /// Problem used by the higher-order schemes: explicit rate -1, implicit rate -5.
    fn mixed_decay() -> StiffLinear {
        StiffLinear {
            lambda_stiff: -5.0,
            lambda_nonstiff: -1.0,
        }
    }

    /// Asserts that the last stored state is within `tol` of `exp(rate * t_final)`.
    fn assert_final_state(result: &IMEXResult<f64>, rate: f64, tol: f64, label: &str) {
        let t_final = *result.t.last().expect("no time points");
        let y_final = result.y.last().expect("no solution")[0];
        let exact = (rate * t_final).exp();
        let err = (y_final - exact).abs();
        assert!(
            err < tol,
            "{}: y={} exact={} err={}",
            label,
            y_final,
            exact,
            err
        );
    }

    #[test]
    fn test_imex_euler_decay() {
        // Purely implicit decay: the explicit part is switched off.
        let sys = StiffLinear {
            lambda_stiff: -10.0,
            lambda_nonstiff: 0.0,
        };
        let result =
            imex_euler(&sys, 0.0, array![1.0], &decay_config(1.0)).expect("IMEX Euler failed");
        assert_final_state(&result, -10.0, 0.05, "IMEX Euler");
    }

    #[test]
    fn test_imex_bdf2_decay() {
        let result = imex_bdf2(&mixed_decay(), 0.0, array![1.0], &decay_config(0.5))
            .expect("IMEX BDF2 failed");
        assert_final_state(&result, -6.0, 0.02, "IMEX BDF2");
    }

    #[test]
    fn test_imex_ark_ssp2_decay() {
        let result = imex_ark_ssp2(&mixed_decay(), 0.0, array![1.0], &decay_config(0.5))
            .expect("IMEX ARK SSP2 failed");
        assert_final_state(&result, -6.0, 0.01, "IMEX ARK SSP2");
    }

    #[test]
    fn test_imex_ark_ssp3_decay() {
        let result = imex_ark_ssp3(&mixed_decay(), 0.0, array![1.0], &decay_config(0.5))
            .expect("IMEX ARK SSP3 failed");
        assert_final_state(&result, -6.0, 0.01, "IMEX ARK SSP3");
    }

    #[test]
    fn test_imex_midpoint_decay() {
        let result = imex_midpoint(&mixed_decay(), 0.0, array![1.0], &decay_config(0.5))
            .expect("IMEX Midpoint failed");
        assert_final_state(&result, -6.0, 0.01, "IMEX Midpoint");
    }
}