Skip to main content

scirs2_integrate/ode/utils/
simd_ops.rs

1//! SIMD-optimized operations for ODE solvers
2//!
3//! This module provides SIMD-accelerated implementations of common operations
4//! used in ODE solving, such as vector arithmetic, norm calculations, and
5//! element-wise function evaluation. These optimizations can provide significant
6//! performance improvements for large systems of ODEs.
7//!
8//! All SIMD operations are delegated to scirs2-core's unified SIMD abstraction layer
9//! in compliance with the project-wide SIMD policy.
10
11#![allow(clippy::missing_transmute_annotations)]
12#![allow(clippy::needless_range_loop)]
13
14use crate::common::IntegrateFloat;
15use crate::error::IntegrateResult;
16use scirs2_core::ndarray::{Array1, ArrayView1, ArrayViewMut1, Zip};
17use scirs2_core::simd_ops::SimdUnifiedOps;
18
19/// SIMD-optimized ODE operations
20pub struct SimdOdeOps;
21
22impl SimdOdeOps {
23    /// Compute y = y + a * dy using SIMD operations
24    pub fn simd_axpy<F: IntegrateFloat + SimdUnifiedOps>(
25        y: &mut ArrayViewMut1<F>,
26        a: F,
27        dy: &ArrayView1<F>,
28    ) {
29        // Use core SIMD operations: y = y + a * dy
30        #[cfg(feature = "simd")]
31        if F::simd_available() {
32            // Compute a * dy
33            let scaled_dy = F::simd_scalar_mul(dy, a);
34            // Add to y
35            let y_view = ArrayView1::from(&*y);
36            let result = F::simd_add(&y_view, &scaled_dy.view());
37            // Copy result back to y
38            y.assign(&result);
39            return;
40        }
41
42        // Fallback implementation
43        Zip::from(y).and(dy).for_each(|y_val, &dy_val| {
44            *y_val += a * dy_val;
45        });
46    }
47
48    /// Compute linear combination: result = a*x + b*y using SIMD
49    pub fn simd_linear_combination<F: IntegrateFloat + SimdUnifiedOps>(
50        x: &ArrayView1<F>,
51        a: F,
52        y: &ArrayView1<F>,
53        b: F,
54    ) -> Array1<F> {
55        #[cfg(feature = "simd")]
56        if F::simd_available() {
57            // Compute a*x and b*y, then add them
58            let ax = F::simd_scalar_mul(x, a);
59            let by = F::simd_scalar_mul(y, b);
60            return F::simd_add(&ax.view(), &by.view());
61        }
62
63        // Fallback implementation
64        let mut result = Array1::zeros(x.len());
65        Zip::from(&mut result)
66            .and(x)
67            .and(y)
68            .for_each(|r, &x_val, &y_val| {
69                *r = a * x_val + b * y_val;
70            });
71        result
72    }
73
74    /// Compute element-wise maximum using SIMD
75    pub fn simd_element_max<F: IntegrateFloat + SimdUnifiedOps>(
76        a: &ArrayView1<F>,
77        b: &ArrayView1<F>,
78    ) -> Array1<F> {
79        #[cfg(feature = "simd")]
80        if F::simd_available() {
81            return F::simd_max(a, b);
82        }
83
84        // Fallback implementation
85        let mut result = Array1::zeros(a.len());
86        Zip::from(&mut result)
87            .and(a)
88            .and(b)
89            .for_each(|r, &a_val, &b_val| {
90                *r = a_val.max(b_val);
91            });
92        result
93    }
94
95    /// Compute element-wise minimum using SIMD
96    pub fn simd_element_min<F: IntegrateFloat + SimdUnifiedOps>(
97        a: &ArrayView1<F>,
98        b: &ArrayView1<F>,
99    ) -> Array1<F> {
100        #[cfg(feature = "simd")]
101        if F::simd_available() {
102            return F::simd_min(a, b);
103        }
104
105        // Fallback implementation
106        let mut result = Array1::zeros(a.len());
107        Zip::from(&mut result)
108            .and(a)
109            .and(b)
110            .for_each(|r, &a_val, &b_val| {
111                *r = a_val.min(b_val);
112            });
113        result
114    }
115
116    /// Compute L2 norm using SIMD
117    pub fn simd_norm_l2<F: IntegrateFloat + SimdUnifiedOps>(x: &ArrayView1<F>) -> F {
118        #[cfg(feature = "simd")]
119        if F::simd_available() {
120            return F::simd_norm(x);
121        }
122
123        // Fallback implementation
124        let mut sum = F::zero();
125        for &val in x.iter() {
126            sum += val * val;
127        }
128        sum.sqrt()
129    }
130
131    /// Compute infinity norm using SIMD
132    pub fn simd_norm_inf<F: IntegrateFloat + SimdUnifiedOps>(x: &ArrayView1<F>) -> F {
133        #[cfg(feature = "simd")]
134        if F::simd_available() {
135            // Use SIMD to compute absolute values and find maximum
136            let abs_x = F::simd_abs(x);
137            return F::simd_max_element(&abs_x.view());
138        }
139
140        // Fallback implementation
141        let mut max_val = F::zero();
142        for &val in x.iter() {
143            let abs_val = val.abs();
144            if abs_val > max_val {
145                max_val = abs_val;
146            }
147        }
148        max_val
149    }
150
151    /// Apply scalar function element-wise using SIMD where possible
152    pub fn simd_map_scalar<F, Func>(x: &ArrayView1<F>, f: Func) -> Array1<F>
153    where
154        F: IntegrateFloat + SimdUnifiedOps,
155        Func: Fn(F) -> F,
156    {
157        // Note: Generic scalar functions cannot be vectorized directly
158        // This is kept for API compatibility but doesn't use SIMD
159        let mut result = Array1::zeros(x.len());
160        Zip::from(&mut result).and(x).for_each(|r, &x_val| {
161            *r = f(x_val);
162        });
163        result
164    }
165}
166
167/// SIMD-optimized dense update for ODE solvers
168///
169/// Computes: y = a0 * y0 + a1 * y1 + a2 * y2 + ... + an * yn
170///
171/// This is a common operation in multistage ODE methods like Runge-Kutta.
172#[allow(dead_code)]
173pub fn simd_dense_update<F: IntegrateFloat + SimdUnifiedOps>(
174    coefficients: &[F],
175    states: &[ArrayView1<F>],
176) -> IntegrateResult<Array1<F>> {
177    if coefficients.is_empty() || states.is_empty() {
178        return Err(crate::error::IntegrateError::ValueError(
179            "Empty coefficients or states".to_string(),
180        ));
181    }
182
183    if coefficients.len() != states.len() {
184        return Err(crate::error::IntegrateError::ValueError(
185            "Coefficients and states must have the same length".to_string(),
186        ));
187    }
188
189    let n = states[0].len();
190    for state in states.iter() {
191        if state.len() != n {
192            return Err(crate::error::IntegrateError::ValueError(
193                "All states must have the same length".to_string(),
194            ));
195        }
196    }
197
198    // Start with the first term
199    let mut result = F::simd_scalar_mul(&states[0], coefficients[0]);
200
201    // Add remaining terms using SIMD FMA when available
202    for (coeff, state) in coefficients[1..].iter().zip(&states[1..]) {
203        let term = F::simd_scalar_mul(state, *coeff);
204        result = F::simd_add(&result.view(), &term.view());
205    }
206
207    Ok(result)
208}
209
210/// SIMD-optimized Runge-Kutta step evaluation
211///
212/// Evaluates: k_new = f(t + c*dt, y + sum(a_ij * k_j * dt))
213#[allow(dead_code)]
214pub fn simd_rk_step<F: IntegrateFloat + SimdUnifiedOps>(
215    y: &ArrayView1<F>,
216    k_stages: &[Array1<F>],
217    coefficients: &[F],
218    dt: F,
219) -> IntegrateResult<Array1<F>> {
220    if coefficients.is_empty() || k_stages.is_empty() {
221        return Ok(y.to_owned());
222    }
223
224    if coefficients.len() != k_stages.len() {
225        return Err(crate::error::IntegrateError::ValueError(
226            "Coefficients and k_stages must have the same length".to_string(),
227        ));
228    }
229
230    // Compute y + sum(a_ij * k_j * dt) using SIMD operations
231    let mut temp_state = y.to_owned();
232
233    for (coeff, k) in coefficients.iter().zip(k_stages.iter()) {
234        let scaled_k = F::simd_scalar_mul(&k.view(), *coeff * dt);
235        temp_state = F::simd_add(&temp_state.view(), &scaled_k.view());
236    }
237
238    Ok(temp_state)
239}
240
241/// SIMD-optimized function evaluation for systems of ODEs
242///
243/// Evaluates multiple ODE functions in parallel when possible.
244#[allow(dead_code)]
245pub fn simd_ode_function_eval<F, Func>(
246    t: F,
247    y: &ArrayView1<F>,
248    f: &Func,
249) -> IntegrateResult<Array1<F>>
250where
251    F: IntegrateFloat + SimdUnifiedOps,
252    Func: Fn(F, &ArrayView1<F>) -> IntegrateResult<Array1<F>>,
253{
254    // Direct function evaluation - SIMD optimizations would be within the function itself
255    f(t, y)
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Absolute tolerance for computed floating-point results.
    const TOL: f64 = 1e-12;

    /// Asserts that two vectors agree element-wise within `TOL`.
    ///
    /// Exact equality on computed floats is brittle because the SIMD and scalar
    /// paths may differ in the last bit. Only results that are copies of input
    /// values (max/min/abs) are compared exactly below.
    fn assert_close(actual: &Array1<f64>, expected: &Array1<f64>) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!((a - e).abs() <= TOL, "index {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_simd_axpy() {
        let mut y = array![1.0, 2.0, 3.0, 4.0];
        let dy = array![0.1, 0.2, 0.3, 0.4];
        let a = 2.0;

        SimdOdeOps::simd_axpy(&mut y.view_mut(), a, &dy.view());

        assert_close(&y, &array![1.2, 2.4, 3.6, 4.8]);
    }

    #[test]
    fn test_simd_linear_combination() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![0.1, 0.2, 0.3, 0.4];
        let a = 2.0;
        let b = 3.0;

        let result = SimdOdeOps::simd_linear_combination(&x.view(), a, &y.view(), b);

        assert_close(&result, &array![2.3, 4.6, 6.9, 9.2]);
    }

    #[test]
    fn test_simd_element_max() {
        let a = array![1.0, 5.0, 3.0, 7.0];
        let b = array![2.0, 4.0, 6.0, 1.0];

        let result = SimdOdeOps::simd_element_max(&a.view(), &b.view());

        assert_eq!(result, array![2.0, 5.0, 6.0, 7.0]);
    }

    #[test]
    fn test_simd_element_min() {
        let a = array![1.0, 5.0, 3.0, 7.0];
        let b = array![2.0, 4.0, 6.0, 1.0];

        let result = SimdOdeOps::simd_element_min(&a.view(), &b.view());

        assert_eq!(result, array![1.0, 4.0, 3.0, 1.0]);
    }

    #[test]
    fn test_simd_norm_l2() {
        let x = array![3.0, 4.0];
        let norm = SimdOdeOps::simd_norm_l2(&x.view());
        assert!((norm - 5.0).abs() <= TOL, "got {norm}");
    }

    #[test]
    fn test_simd_norm_inf() {
        let x = array![-3.0, 4.0, -5.0, 2.0];
        let norm = SimdOdeOps::simd_norm_inf(&x.view());
        assert_eq!(norm, 5.0);
    }

    #[test]
    fn test_simd_map_scalar() {
        let x = array![1.0, -2.0, 3.5];
        let result = SimdOdeOps::simd_map_scalar(&x.view(), |v| v * 2.0);
        assert_eq!(result, array![2.0, -4.0, 7.0]);
    }

    #[test]
    fn test_simd_dense_update() {
        let s0 = array![1.0, 1.0];
        let s1 = array![0.5, 0.25];

        let result =
            simd_dense_update(&[1.0, 2.0], &[s0.view(), s1.view()]).expect("valid inputs");

        assert_close(&result, &array![2.0, 1.5]);
    }

    #[test]
    fn test_simd_dense_update_rejects_bad_input() {
        let s0 = array![1.0, 1.0];
        let s1 = array![0.5, 0.25];
        let short = array![1.0];

        // Empty input.
        assert!(simd_dense_update::<f64>(&[], &[]).is_err());
        // Coefficient/state count mismatch.
        assert!(simd_dense_update(&[1.0], &[s0.view(), s1.view()]).is_err());
        // States of differing lengths.
        assert!(simd_dense_update(&[1.0, 1.0], &[s0.view(), short.view()]).is_err());
    }

    #[test]
    fn test_simd_rk_step() {
        let y = array![0.0, 0.0];
        let k_stages = [array![1.0, 2.0], array![3.0, 4.0]];

        let result =
            simd_rk_step(&y.view(), &k_stages, &[0.5, 0.5], 0.1).expect("valid inputs");
        assert_close(&result, &array![0.2, 0.3]);

        // No stages: the state comes back unchanged.
        let unchanged = simd_rk_step(&y.view(), &[], &[], 0.1).expect("empty stages");
        assert_eq!(unchanged, y);

        // Both non-empty but with different lengths.
        assert!(simd_rk_step(&y.view(), &k_stages, &[1.0], 0.1).is_err());
    }
}