scirs2_integrate/ode/utils/
simd_ops.rs1#![allow(clippy::missing_transmute_annotations)]
12#![allow(clippy::needless_range_loop)]
13
14use crate::common::IntegrateFloat;
15use crate::error::IntegrateResult;
16use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1, Zip};
17use scirs2_core::simd_ops::SimdUnifiedOps;
18
19pub struct SimdOdeOps;
21
22impl SimdOdeOps {
23 pub fn simd_axpy<F: IntegrateFloat + SimdUnifiedOps>(
25 y: &mut ArrayViewMut1<F>,
26 a: F,
27 dy: &ArrayView1<F>,
28 ) {
29 #[cfg(feature = "simd")]
31 if F::simd_available() {
32 let scaled_dy = F::simd_scalar_mul(dy, a);
34 let y_view = ArrayView1::from(&*y);
36 let result = F::simd_add(&y_view, &scaled_dy.view());
37 y.assign(&result);
39 return;
40 }
41
42 Zip::from(y).and(dy).for_each(|y_val, &dy_val| {
44 *y_val += a * dy_val;
45 });
46 }
47
48 pub fn simd_linear_combination<F: IntegrateFloat + SimdUnifiedOps>(
50 x: &ArrayView1<F>,
51 a: F,
52 y: &ArrayView1<F>,
53 b: F,
54 ) -> Array1<F> {
55 #[cfg(feature = "simd")]
56 if F::simd_available() {
57 let ax = F::simd_scalar_mul(x, a);
59 let by = F::simd_scalar_mul(y, b);
60 return F::simd_add(&ax.view(), &by.view());
61 }
62
63 let mut result = Array1::zeros(x.len());
65 Zip::from(&mut result)
66 .and(x)
67 .and(y)
68 .for_each(|r, &x_val, &y_val| {
69 *r = a * x_val + b * y_val;
70 });
71 result
72 }
73
74 pub fn simd_element_max<F: IntegrateFloat + SimdUnifiedOps>(
76 a: &ArrayView1<F>,
77 b: &ArrayView1<F>,
78 ) -> Array1<F> {
79 #[cfg(feature = "simd")]
80 if F::simd_available() {
81 return F::simd_max(a, b);
82 }
83
84 let mut result = Array1::zeros(a.len());
86 Zip::from(&mut result)
87 .and(a)
88 .and(b)
89 .for_each(|r, &a_val, &b_val| {
90 *r = a_val.max(b_val);
91 });
92 result
93 }
94
95 pub fn simd_element_min<F: IntegrateFloat + SimdUnifiedOps>(
97 a: &ArrayView1<F>,
98 b: &ArrayView1<F>,
99 ) -> Array1<F> {
100 #[cfg(feature = "simd")]
101 if F::simd_available() {
102 return F::simd_min(a, b);
103 }
104
105 let mut result = Array1::zeros(a.len());
107 Zip::from(&mut result)
108 .and(a)
109 .and(b)
110 .for_each(|r, &a_val, &b_val| {
111 *r = a_val.min(b_val);
112 });
113 result
114 }
115
116 pub fn simd_norm_l2<F: IntegrateFloat + SimdUnifiedOps>(x: &ArrayView1<F>) -> F {
118 #[cfg(feature = "simd")]
119 if F::simd_available() {
120 return F::simd_norm(x);
121 }
122
123 let mut sum = F::zero();
125 for &val in x.iter() {
126 sum += val * val;
127 }
128 sum.sqrt()
129 }
130
131 pub fn simd_norm_inf<F: IntegrateFloat + SimdUnifiedOps>(x: &ArrayView1<F>) -> F {
133 #[cfg(feature = "simd")]
134 if F::simd_available() {
135 let abs_x = F::simd_abs(x);
137 return F::simd_max_element(&abs_x.view());
138 }
139
140 let mut max_val = F::zero();
142 for &val in x.iter() {
143 let abs_val = val.abs();
144 if abs_val > max_val {
145 max_val = abs_val;
146 }
147 }
148 max_val
149 }
150
151 pub fn simd_map_scalar<F, Func>(x: &ArrayView1<F>, f: Func) -> Array1<F>
153 where
154 F: IntegrateFloat + SimdUnifiedOps,
155 Func: Fn(F) -> F,
156 {
157 let mut result = Array1::zeros(x.len());
160 Zip::from(&mut result).and(x).for_each(|r, &x_val| {
161 *r = f(x_val);
162 });
163 result
164 }
165}
166
167#[allow(dead_code)]
173pub fn simd_dense_update<F: IntegrateFloat + SimdUnifiedOps>(
174 coefficients: &[F],
175 states: &[ArrayView1<F>],
176) -> IntegrateResult<Array1<F>> {
177 if coefficients.is_empty() || states.is_empty() {
178 return Err(crate::error::IntegrateError::ValueError(
179 "Empty coefficients or states".to_string(),
180 ));
181 }
182
183 if coefficients.len() != states.len() {
184 return Err(crate::error::IntegrateError::ValueError(
185 "Coefficients and states must have the same length".to_string(),
186 ));
187 }
188
189 let n = states[0].len();
190 for state in states.iter() {
191 if state.len() != n {
192 return Err(crate::error::IntegrateError::ValueError(
193 "All states must have the same length".to_string(),
194 ));
195 }
196 }
197
198 let mut result = F::simd_scalar_mul(&states[0], coefficients[0]);
200
201 for (coeff, state) in coefficients[1..].iter().zip(&states[1..]) {
203 let term = F::simd_scalar_mul(state, *coeff);
204 result = F::simd_add(&result.view(), &term.view());
205 }
206
207 Ok(result)
208}
209
210#[allow(dead_code)]
214pub fn simd_rk_step<F: IntegrateFloat + SimdUnifiedOps>(
215 y: &ArrayView1<F>,
216 k_stages: &[Array1<F>],
217 coefficients: &[F],
218 dt: F,
219) -> IntegrateResult<Array1<F>> {
220 if coefficients.is_empty() || k_stages.is_empty() {
221 return Ok(y.to_owned());
222 }
223
224 if coefficients.len() != k_stages.len() {
225 return Err(crate::error::IntegrateError::ValueError(
226 "Coefficients and k_stages must have the same length".to_string(),
227 ));
228 }
229
230 let mut temp_state = y.to_owned();
232
233 for (coeff, k) in coefficients.iter().zip(k_stages.iter()) {
234 let scaled_k = F::simd_scalar_mul(&k.view(), *coeff * dt);
235 temp_state = F::simd_add(&temp_state.view(), &scaled_k.view());
236 }
237
238 Ok(temp_state)
239}
240
241#[allow(dead_code)]
245pub fn simd_ode_function_eval<F, Func>(
246 t: F,
247 y: &ArrayView1<F>,
248 f: &Func,
249) -> IntegrateResult<Array1<F>>
250where
251 F: IntegrateFloat + SimdUnifiedOps,
252 Func: Fn(F, &ArrayView1<F>) -> IntegrateResult<Array1<F>>,
253{
254 f(t, y)
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Absolute tolerance for comparing computed floating-point results.
    /// The kernels may take a SIMD or an element-wise path, so the tests do
    /// not rely on bit-exact equality.
    const TOL: f64 = 1e-12;

    /// Asserts that two scalars agree to within `TOL`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() <= TOL,
            "got {}, expected {}",
            actual,
            expected
        );
    }

    /// Asserts that two vectors have equal length and agree element-wise to
    /// within `TOL`.
    fn assert_all_close(actual: &Array1<f64>, expected: &Array1<f64>) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (idx, (&got, &want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() <= TOL,
                "index {}: got {}, expected {}",
                idx,
                got,
                want
            );
        }
    }

    #[test]
    fn test_simd_axpy() {
        let mut y = array![1.0, 2.0, 3.0, 4.0];
        let dy = array![0.1, 0.2, 0.3, 0.4];
        let a = 2.0;

        SimdOdeOps::simd_axpy(&mut y.view_mut(), a, &dy.view());

        assert_all_close(&y, &array![1.2, 2.4, 3.6, 4.8]);
    }

    #[test]
    fn test_simd_linear_combination() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![0.1, 0.2, 0.3, 0.4];
        let a = 2.0;
        let b = 3.0;

        let result = SimdOdeOps::simd_linear_combination(&x.view(), a, &y.view(), b);

        assert_all_close(&result, &array![2.3, 4.6, 6.9, 9.2]);
    }

    #[test]
    fn test_simd_element_max() {
        let a = array![1.0, 5.0, 3.0, 7.0];
        let b = array![2.0, 4.0, 6.0, 1.0];

        let result = SimdOdeOps::simd_element_max(&a.view(), &b.view());

        assert_all_close(&result, &array![2.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_simd_element_min() {
        let a = array![1.0, 5.0, 3.0, 7.0];
        let b = array![2.0, 4.0, 6.0, 1.0];

        let result = SimdOdeOps::simd_element_min(&a.view(), &b.view());

        assert_all_close(&result, &array![1.0, 4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_simd_norm_l2() {
        let x = array![3.0, 4.0];
        let norm = SimdOdeOps::simd_norm_l2(&x.view());
        assert_close(norm, 5.0);
    }

    #[test]
    fn test_simd_norm_inf() {
        let x = array![-3.0, 4.0, -5.0, 2.0];
        let norm = SimdOdeOps::simd_norm_inf(&x.view());
        assert_close(norm, 5.0);
    }

    #[test]
    fn test_simd_map_scalar() {
        let x = array![1.0, -2.0, 3.0];

        let result = SimdOdeOps::simd_map_scalar(&x.view(), |v: f64| v * v);

        assert_all_close(&result, &array![1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_simd_dense_update() {
        let s0 = array![1.0, 2.0];
        let s1 = array![10.0, 20.0];

        let result = simd_dense_update(&[2.0, 0.5], &[s0.view(), s1.view()])
            .expect("matching inputs must succeed");

        assert_all_close(&result, &array![7.0, 14.0]);
    }

    #[test]
    fn test_simd_dense_update_rejects_bad_input() {
        let s0 = array![1.0, 2.0];
        let s1 = array![1.0, 2.0, 3.0];

        // Empty inputs.
        assert!(simd_dense_update::<f64>(&[], &[]).is_err());
        // Coefficient count differs from state count.
        assert!(simd_dense_update(&[1.0, 2.0], &[s0.view()]).is_err());
        // States of different dimension.
        assert!(simd_dense_update(&[1.0, 2.0], &[s0.view(), s1.view()]).is_err());
    }

    #[test]
    fn test_simd_rk_step() {
        let y = array![1.0, 2.0];
        let k_stages = vec![array![1.0, 1.0], array![2.0, 2.0]];

        let result = simd_rk_step(&y.view(), &k_stages, &[0.5, 0.5], 0.1)
            .expect("matching inputs must succeed");

        // y + 0.05 * k0 + 0.05 * k1
        assert_all_close(&result, &array![1.15, 2.15]);
    }
}