1use crate::common::IntegrateFloat;
8use crate::error::IntegrateResult;
9use crate::ode::types::{ODEOptions, ODEResult};
10use crate::ode::utils::common::{estimate_initial_step, ODEState, StepResult};
11use scirs2_core::ndarray::{Array1, ArrayView1};
12
13#[cfg(feature = "simd")]
14use crate::ode::utils::simd_ops::SimdOdeOps;
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
/// Integrates `dy/dt = f(t, y)` over `t_span` with the classical fixed-step
/// fourth-order Runge-Kutta scheme, delegating each step to `simd_rk4_step`.
///
/// # Arguments
///
/// * `f` - Right-hand side of the system, called as `f(t, y)`.
/// * `t_span` - Integration interval `[t_start, t_end]`.
/// * `y0` - State at `t_start`.
/// * `opts` - Solver options. Only `h0`, `atol` and `rtol` are read here: when
///   `h0` is `None` the step size comes from `estimate_initial_step` using the
///   tolerance `atol + rtol`.
///
/// # Returns
///
/// An `ODEResult` holding every visited time and state (including the initial
/// point). The last step is shortened so that the final time is exactly
/// `t_end`. If `t_start >= t_end` no step is taken and only the initial point
/// is returned.
///
/// # Errors
///
/// * `IntegrateError::StepSizeTooSmall` if the interval is non-empty but the
///   step size is zero, negative or NaN (the integration could never finish).
/// * `IntegrateError::ConvergenceError` if more than 1,000,000 steps are taken.
/// * Any error returned by `simd_rk4_step`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_rk4_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;

    // Use the caller's step size when given, otherwise estimate one.
    let h = opts.h0.unwrap_or_else(|| {
        let dy0 = f(t_start, y0.view());
        let tol = opts.atol + opts.rtol;
        estimate_initial_step(&f, t_start, &y0, &dy0, tol, t_end)
    });

    // A non-advancing step would otherwise spin until the step limit while
    // storing a copy of the state on every iteration.
    if t_start < t_end && (h.is_nan() || h <= F::zero()) {
        return Err(crate::error::IntegrateError::StepSizeTooSmall(
            "Step size must be positive in SIMD RK4 method".to_string(),
        ));
    }

    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];

    let mut t = t_start;
    let mut y = y0;
    let mut steps = 0;
    let mut func_evals = 0;

    while t < t_end {
        // Shorten the last step so the integration stops at `t_end`.
        let is_final = t + h >= t_end;
        let h_current = if is_final { t_end - t } else { h };

        let (y_new, n_evals) = simd_rk4_step(&f, t, &y.view(), h_current)?;
        func_evals += n_evals;

        // `t + (t_end - t)` may differ from `t_end` by a rounding error, which
        // would trigger one more, vanishingly small, step. Land on it exactly.
        t = if is_final { t_end } else { t + h_current };
        y = y_new;
        steps += 1;

        t_values.push(t);
        y_values.push(y.clone());

        if steps > 1_000_000 {
            return Err(crate::error::IntegrateError::ConvergenceError(
                "Maximum number of steps exceeded in SIMD RK4 method".to_string(),
            ));
        }
    }

    // A fixed-step method never rejects a step, so every step is "accepted".
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        n_steps: steps,
        n_eval: func_evals,
        n_accepted: steps,
        n_rejected: 0,
        n_lu: 0,
        n_jac: 0,
        method: crate::ode::types::ODEMethod::RK4,
        success: true,
        message: Some("Integration completed successfully".to_string()),
    })
}
103
/// Integrates `dy/dt = f(t, y)` over `t_span` with an adaptive embedded
/// Runge-Kutta 4(5) pair, delegating each trial step to `simd_rk45_step`.
///
/// The step size is adapted from a scaled max-norm of the difference between
/// the two solutions returned by `simd_rk45_step`; a trial step is accepted
/// when that norm is at most one.
///
/// # Arguments
///
/// * `f` - Right-hand side of the system, called as `f(t, y)`.
/// * `t_span` - Integration interval `[t_start, t_end]`.
/// * `y0` - State at `t_start`.
/// * `opts` - Solver options. Reads `h0` (initial step; estimated with
///   `estimate_initial_step` when `None`), `min_step` (default `1e-12`),
///   `max_step` (default one tenth of the interval), `atol` and `rtol`.
///
/// # Returns
///
/// An `ODEResult` holding every accepted time and state (including the initial
/// point). `n_accepted` and `n_rejected` count accepted and rejected trial
/// steps and `n_steps` is their sum. The final step is truncated, even below
/// `min_step`, so that the last time is exactly `t_end`.
///
/// # Errors
///
/// * `IntegrateError::StepSizeTooSmall` if a rejected step pushes the step
///   size below `min_step`.
/// * `IntegrateError::ConvergenceError` if more than 100,000 steps are accepted.
/// * Any error returned by `simd_rk45_step`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn simd_rk45_method<F, Func>(
    f: Func,
    t_span: [F; 2],
    y0: Array1<F>,
    opts: ODEOptions<F>,
) -> IntegrateResult<ODEResult<F>>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let [t_start, t_end] = t_span;

    let mut h = opts.h0.unwrap_or_else(|| {
        let dy0 = f(t_start, y0.view());
        let tol = opts.atol + opts.rtol;
        estimate_initial_step(&f, t_start, &y0, &dy0, tol, t_end)
    });

    let min_step = opts
        .min_step
        .unwrap_or_else(|| F::from_f64(1e-12).unwrap());
    let max_step = opts
        .max_step
        .unwrap_or_else(|| (t_end - t_start) / F::from_f64(10.0).unwrap());
    let abs_tol = opts.atol;
    let rel_tol = opts.rtol;

    // Step-size controller constants; none of them change between steps.
    let one = F::one();
    let order = F::from_f64(5.0).unwrap();
    let exponent = one / (order + one);
    let safety = F::from_f64(0.9).unwrap();
    let factor_min = F::from_f64(0.2).unwrap();
    let factor_max = F::from_f64(5.0).unwrap();
    let small_error = F::from_f64(0.1).unwrap();
    let min_growth = F::from_f64(2.0).unwrap();

    let mut t_values = vec![t_start];
    let mut y_values = vec![y0.clone()];

    let mut t = t_start;
    let mut y = y0;
    let mut accepted_steps = 0;
    let mut func_evals = 0;
    let mut rejected_steps = 0;

    while t < t_end {
        // Apply the step bounds first ...
        h = h.min(max_step).max(min_step);

        // ... and only then truncate to the remaining interval. Clamping after
        // truncating could raise the last step back to `min_step` and carry
        // the solution past `t_end`.
        let is_final = t + h >= t_end;
        if is_final {
            h = t_end - t;
        }

        let (y_new, y_star, n_evals) = simd_rk45_step(&f, t, &y.view(), h)?;
        func_evals += n_evals;

        // Max-norm of the difference between both solutions, each component
        // scaled by its mixed absolute/relative tolerance.
        let err_norm = y_new
            .iter()
            .zip(y_star.iter())
            .fold(F::zero(), |worst, (&high, &low)| {
                let scale = abs_tol + rel_tol * high.abs();
                worst.max((high - low).abs() / scale)
            });

        // Proposed step-size change, limited to [factor_min, factor_max].
        let factor = (safety * (one / err_norm).powf(exponent))
            .min(factor_max)
            .max(factor_min);

        if err_norm <= one {
            // Land exactly on `t_end`; `t + (t_end - t)` may miss it by a
            // rounding error and cause one more, vanishingly small, step.
            t = if is_final { t_end } else { t + h };
            y = y_new;
            accepted_steps += 1;

            t_values.push(t);
            y_values.push(y.clone());

            // Very accurate steps grow by at least `min_growth`.
            h *= if err_norm <= small_error {
                factor.max(min_growth)
            } else {
                factor
            };
        } else {
            rejected_steps += 1;
            // Never grow the step after a rejection.
            h *= factor.min(one);

            if h < min_step {
                return Err(crate::error::IntegrateError::StepSizeTooSmall(
                    "Step size became too small in SIMD RK45 method".to_string(),
                ));
            }
        }

        if accepted_steps > 100_000 {
            return Err(crate::error::IntegrateError::ConvergenceError(
                "Maximum number of steps exceeded in SIMD RK45 method".to_string(),
            ));
        }
    }

    // `accepted_steps` only counts accepted steps, so it is reported as-is;
    // subtracting the rejections from it would under-report and could
    // underflow when rejections outnumber acceptances.
    Ok(ODEResult {
        t: t_values,
        y: y_values,
        n_steps: accepted_steps + rejected_steps,
        n_eval: func_evals,
        n_accepted: accepted_steps,
        n_rejected: rejected_steps,
        n_lu: 0,
        n_jac: 0,
        method: crate::ode::types::ODEMethod::RK45,
        success: true,
        message: Some("Integration completed successfully".to_string()),
    })
}
238
/// Advances the state by one classical fourth-order Runge-Kutta step of size
/// `h`, using the `SimdUnifiedOps` vector primitives for the stage arithmetic.
///
/// # Arguments
///
/// * `f` - Right-hand side of the system, called as `f(t, y)`.
/// * `t` - Time at the start of the step.
/// * `y` - State at time `t`.
/// * `h` - Step size.
///
/// # Returns
///
/// The state at `t + h` together with the number of right-hand-side
/// evaluations performed, which is always 4. The current implementation never
/// returns `Err`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn simd_rk4_step<F, Func>(
    f: &Func,
    t: F,
    y: &ArrayView1<F>,
    h: F,
) -> IntegrateResult<(Array1<F>, usize)>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    let h_half = h * F::from_f64(0.5).unwrap();

    // `y` is already a view, so it is handed to `f` directly instead of being
    // copied into a fresh array first.
    let k1 = f(t, y.view());

    // Two midpoint slopes, each started from the previous slope.
    let y_temp1 = F::simd_add(y, &F::simd_scalar_mul(&k1.view(), h_half).view());
    let k2 = f(t + h_half, y_temp1.view());

    let y_temp2 = F::simd_add(y, &F::simd_scalar_mul(&k2.view(), h_half).view());
    let k3 = f(t + h_half, y_temp2.view());

    // Slope at the end of the step.
    let y_temp3 = F::simd_add(y, &F::simd_scalar_mul(&k3.view(), h).view());
    let k4 = f(t + h, y_temp3.view());

    // Weights 1/6, 2/6, 2/6, 1/6 for k1..k4.
    let c1 = F::one() / F::from_f64(6.0).unwrap();
    let c2 = F::from_f64(2.0).unwrap() / F::from_f64(6.0).unwrap();

    let term1 = F::simd_scalar_mul(&k1.view(), c1 * h);
    let term2 = F::simd_scalar_mul(&k2.view(), c2 * h);
    let term3 = F::simd_scalar_mul(&k3.view(), c2 * h);
    let term4 = F::simd_scalar_mul(&k4.view(), c1 * h);

    let sum12 = F::simd_add(&term1.view(), &term2.view());
    let sum34 = F::simd_add(&term3.view(), &term4.view());
    let weighted_sum = F::simd_add(&sum12.view(), &sum34.view());

    let y_new = F::simd_add(y, &weighted_sum.view());

    Ok((y_new, 4))
}
287
/// Performs one trial step of size `h` with the Dormand-Prince embedded
/// Runge-Kutta pair, using the `SimdUnifiedOps` vector primitives for the
/// stage arithmetic.
///
/// # Arguments
///
/// * `f` - Right-hand side of the system, called as `f(t, y)`.
/// * `t` - Time at the start of the step.
/// * `y` - State at time `t`.
/// * `h` - Step size.
///
/// # Returns
///
/// `(y_new, y_star, n_evals)`: the fifth-order solution at `t + h`, the
/// embedded fourth-order solution used for error estimation, and the number of
/// right-hand-side evaluations, which is always 7. The current implementation
/// never returns `Err`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn simd_rk45_step<F, Func>(
    f: &Func,
    t: F,
    y: &ArrayView1<F>,
    h: F,
) -> IntegrateResult<(Array1<F>, Array1<F>, usize)>
where
    F: IntegrateFloat + SimdUnifiedOps,
    Func: Fn(F, ArrayView1<F>) -> Array1<F>,
{
    // Small local helpers to keep the tableau readable.
    let coeff = |value: f64| F::from_f64(value).unwrap();
    // `weight * h * k`
    let scaled = |k: &Array1<F>, weight: F| F::simd_scalar_mul(&k.view(), weight * h);
    // `a + b`
    let add = |a: &Array1<F>, b: &Array1<F>| F::simd_add(&a.view(), &b.view());
    // `y + delta`
    let from_y = |delta: &Array1<F>| F::simd_add(y, &delta.view());

    // Stage coefficients a_ij of the Dormand-Prince tableau.
    let a21 = coeff(1.0 / 5.0);
    let a31 = coeff(3.0 / 40.0);
    let a32 = coeff(9.0 / 40.0);
    let a41 = coeff(44.0 / 45.0);
    let a42 = coeff(-56.0 / 15.0);
    let a43 = coeff(32.0 / 9.0);
    let a51 = coeff(19372.0 / 6561.0);
    let a52 = coeff(-25360.0 / 2187.0);
    let a53 = coeff(64448.0 / 6561.0);
    let a54 = coeff(-212.0 / 729.0);
    let a61 = coeff(9017.0 / 3168.0);
    let a62 = coeff(-355.0 / 33.0);
    let a63 = coeff(46732.0 / 5247.0);
    let a64 = coeff(49.0 / 176.0);
    let a65 = coeff(-5103.0 / 18656.0);

    // Stage 1. `y` is already a view, so it is handed to `f` directly instead
    // of being copied into a fresh array first.
    let k1 = f(t, y.view());

    // Stage 2 at t + h/5.
    let y2 = from_y(&scaled(&k1, a21));
    let k2 = f(t + h * a21, y2.view());

    // Stage 3 at t + 3h/10.
    let y3 = from_y(&add(&scaled(&k1, a31), &scaled(&k2, a32)));
    let k3 = f(t + h * coeff(3.0 / 10.0), y3.view());

    // Stage 4 at t + 4h/5.
    let pair4 = add(&scaled(&k1, a41), &scaled(&k2, a42));
    let y4 = from_y(&add(&pair4, &scaled(&k3, a43)));
    let k4 = f(t + h * coeff(4.0 / 5.0), y4.view());

    // Stage 5 at t + 8h/9.
    let head5 = add(&scaled(&k1, a51), &scaled(&k2, a52));
    let tail5 = add(&scaled(&k3, a53), &scaled(&k4, a54));
    let y5 = from_y(&add(&head5, &tail5));
    let k5 = f(t + h * coeff(8.0 / 9.0), y5.view());

    // Stage 6 at t + h.
    let head6 = add(&scaled(&k1, a61), &scaled(&k2, a62));
    let tail6 = add(&scaled(&k3, a63), &scaled(&k4, a64));
    let body6 = add(&head6, &tail6);
    let y6 = from_y(&add(&body6, &scaled(&k5, a65)));
    let k6 = f(t + h, y6.view());

    // Fifth-order weights (the weight of k2 is zero).
    let b1 = coeff(35.0 / 384.0);
    let b3 = coeff(500.0 / 1113.0);
    let b4 = coeff(125.0 / 192.0);
    let b5 = coeff(-2187.0 / 6784.0);
    let b6 = coeff(11.0 / 84.0);

    let head_b = add(&scaled(&k1, b1), &scaled(&k3, b3));
    let tail_b = add(&scaled(&k4, b4), &scaled(&k5, b5));
    let body_b = add(&head_b, &tail_b);
    let y_new = from_y(&add(&body_b, &scaled(&k6, b6)));

    // Seventh slope, evaluated at the new solution; only the embedded
    // fourth-order formula uses it.
    let k7 = f(t + h, y_new.view());

    // Embedded fourth-order weights (the weight of k2 is zero).
    let b1_star = coeff(5179.0 / 57600.0);
    let b3_star = coeff(7571.0 / 16695.0);
    let b4_star = coeff(393.0 / 640.0);
    let b5_star = coeff(-92097.0 / 339200.0);
    let b6_star = coeff(187.0 / 2100.0);
    let b7_star = coeff(1.0 / 40.0);

    let head_s = add(&scaled(&k1, b1_star), &scaled(&k3, b3_star));
    let mid_s = add(&scaled(&k4, b4_star), &scaled(&k5, b5_star));
    let tail_s = add(&scaled(&k6, b6_star), &scaled(&k7, b7_star));
    let body_s = add(&head_s, &mid_s);
    let y_star = from_y(&add(&body_s, &tail_s));

    Ok((y_new, y_star, 7))
}
406
407#[cfg(not(feature = "simd"))]
409#[allow(dead_code)]
410pub fn simd_rk4_method<F, Func>(
411 f: Func,
412 t_span: [F; 2],
413 y0: Array1<F>,
414 opts: ODEOptions<F>,
415) -> IntegrateResult<ODEResult<F>>
416where
417 F: IntegrateFloat + SimdUnifiedOps,
418 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
419{
420 let h = opts.h0.unwrap_or_else(|| {
422 let dy0 = f(t_span[0], y0.view());
423 let tol = opts.atol + opts.rtol;
424 estimate_initial_step(&f, t_span[0], &y0, &dy0, tol, t_span[1])
425 });
426 crate::ode::methods::explicit::rk4_method(f, t_span, y0, h, opts)
427}
428
429#[cfg(not(feature = "simd"))]
430#[allow(dead_code)]
431pub fn simd_rk45_method<F, Func>(
432 f: Func,
433 t_span: [F; 2],
434 y0: Array1<F>,
435 opts: ODEOptions<F>,
436) -> IntegrateResult<ODEResult<F>>
437where
438 F: IntegrateFloat + SimdUnifiedOps,
439 Func: Fn(F, ArrayView1<F>) -> Array1<F>,
440{
441 crate::ode::methods::adaptive::rk45_method(f, t_span, y0, opts)
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr1;

    /// Exponential decay `y' = -y`, `y(0) = 1`: the value at `t = 1` must be
    /// close to `exp(-1)`.
    #[test]
    #[cfg(feature = "simd")]
    fn test_simd_rk4_simple() {
        let decay = |_t: f64, y: ArrayView1<f64>| -> Array1<f64> { -y.to_owned() };
        let options = ODEOptions {
            h0: Some(0.1),
            ..Default::default()
        };

        let solution = simd_rk4_method(decay, [0.0, 1.0], arr1(&[1.0]), options).unwrap();

        let y_end = solution.y.last().unwrap()[0];
        let expected = (-1.0_f64).exp();

        assert_relative_eq!(y_end, expected, epsilon = 1e-3);
        assert!(solution.success);
    }

    /// Harmonic oscillator `y0' = y1`, `y1' = -y0` started at `(1, 0)`: the
    /// state at `t = pi` must be close to `(-1, 0)`.
    #[test]
    #[cfg(feature = "simd")]
    fn test_simd_rk45_adaptive() {
        let oscillator =
            |_t: f64, y: ArrayView1<f64>| -> Array1<f64> { arr1(&[y[1], -y[0]]) };

        let initial_state = arr1(&[1.0, 0.0]);
        let interval = [0.0, std::f64::consts::PI];
        let options = ODEOptions {
            atol: 1e-8,
            rtol: 1e-8,
            h0: Some(0.1),
            ..Default::default()
        };

        let solution = simd_rk45_method(oscillator, interval, initial_state, options).unwrap();

        let state_end = solution.y.last().unwrap();
        assert_relative_eq!(state_end[0], -1.0, epsilon = 1e-6);
        assert_relative_eq!(state_end[1], 0.0, epsilon = 1e-6);
        assert!(solution.success);
    }
}