Skip to main content

scirs2/
integrate.rs

1//! Python bindings for scirs2-integrate
2//!
3//! Provides numerical integration similar to scipy.integrate
4
5// Callbacks acquire the GIL via Python::attach (pyo3 0.28+) since they are invoked
6// from Rust integration routines where the caller's Python<'_> token is not in scope.
7
8use pyo3::prelude::*;
9use pyo3::types::PyDict;
10use scirs2_core::ndarray::{Array1 as Array1_17, ArrayView1};
11use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
12use scirs2_numpy::{PyArray1, PyReadonlyArray1};
13
14use scirs2_integrate::ode::{solve_ivp, ODEMethod, ODEOptions};
15use scirs2_integrate::quad::{quad, QuadOptions};
16
17// =============================================================================
18// Array-based Integration (works without callbacks)
19// =============================================================================
20
21/// Integrate using array data (y values at x points) - trapezoidal rule
22///
23/// Similar to scipy.integrate.trapezoid
24#[pyfunction]
25#[pyo3(signature = (y, x=None, dx=1.0))]
26fn trapezoid_array_py(
27    y: PyReadonlyArray1<f64>,
28    x: Option<PyReadonlyArray1<f64>>,
29    dx: f64,
30) -> PyResult<f64> {
31    let y_arr = y.as_array();
32
33    if y_arr.len() < 2 {
34        return Err(pyo3::exceptions::PyValueError::new_err(
35            "Need at least 2 points",
36        ));
37    }
38
39    let result = if let Some(x_py) = x {
40        let x_arr = x_py.as_array();
41        if x_arr.len() != y_arr.len() {
42            return Err(pyo3::exceptions::PyValueError::new_err(
43                "x and y must have same length",
44            ));
45        }
46        // Non-uniform spacing
47        let mut total = 0.0;
48        for i in 0..y_arr.len() - 1 {
49            let dx = x_arr[i + 1] - x_arr[i];
50            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
51        }
52        total
53    } else {
54        // Uniform spacing with provided dx
55        let mut total = 0.0;
56        for i in 0..y_arr.len() - 1 {
57            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
58        }
59        total
60    };
61
62    Ok(result)
63}
64
65/// Integrate using array data - Simpson's rule
66///
67/// Similar to scipy.integrate.simpson
68#[pyfunction]
69#[pyo3(signature = (y, x=None, dx=1.0))]
70fn simpson_array_py(
71    y: PyReadonlyArray1<f64>,
72    x: Option<PyReadonlyArray1<f64>>,
73    dx: f64,
74) -> PyResult<f64> {
75    let y_arr = y.as_array();
76    let n = y_arr.len();
77
78    if n < 3 {
79        return Err(pyo3::exceptions::PyValueError::new_err(
80            "Need at least 3 points",
81        ));
82    }
83
84    // Use Simpson's rule for even number of intervals, fall back to trapezoid for odd
85    let result = if let Some(x_py) = x {
86        let x_arr = x_py.as_array();
87        if x_arr.len() != y_arr.len() {
88            return Err(pyo3::exceptions::PyValueError::new_err(
89                "x and y must have same length",
90            ));
91        }
92
93        let mut total = 0.0;
94        let mut i = 0;
95        while i + 2 < n {
96            let h = (x_arr[i + 2] - x_arr[i]) / 2.0;
97            total += h / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
98            i += 2;
99        }
100        // Handle remaining interval with trapezoid
101        if i + 1 < n {
102            let h = x_arr[i + 1] - x_arr[i];
103            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * h;
104        }
105        total
106    } else {
107        let mut total = 0.0;
108        let mut i = 0;
109        while i + 2 < n {
110            total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
111            i += 2;
112        }
113        // Handle remaining interval with trapezoid
114        if i + 1 < n {
115            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
116        }
117        total
118    };
119
120    Ok(result)
121}
122
123/// Cumulative trapezoidal integration
124///
125/// Similar to scipy.integrate.cumulative_trapezoid
126#[pyfunction]
127#[pyo3(signature = (y, x=None, dx=1.0, initial=None))]
128fn cumulative_trapezoid_py(
129    py: Python,
130    y: PyReadonlyArray1<f64>,
131    x: Option<PyReadonlyArray1<f64>>,
132    dx: f64,
133    initial: Option<f64>,
134) -> PyResult<Py<PyArray1<f64>>> {
135    let y_arr = y.as_array();
136
137    if y_arr.len() < 2 {
138        return Err(pyo3::exceptions::PyValueError::new_err(
139            "Need at least 2 points",
140        ));
141    }
142
143    let n = y_arr.len();
144    let has_initial = initial.is_some();
145    let mut result = Vec::with_capacity(if has_initial { n } else { n - 1 });
146
147    if let Some(init) = initial {
148        result.push(init);
149    }
150
151    let mut cumsum = initial.unwrap_or(0.0);
152
153    if let Some(x_py) = x {
154        let x_arr = x_py.as_array();
155        for i in 0..y_arr.len() - 1 {
156            let dx_i = x_arr[i + 1] - x_arr[i];
157            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx_i;
158            result.push(cumsum);
159        }
160    } else {
161        for i in 0..y_arr.len() - 1 {
162            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
163            result.push(cumsum);
164        }
165    }
166
167    scirs_to_numpy_array1(Array1::from_vec(result), py)
168}
169
170/// Romberg integration using array data
171#[pyfunction]
172fn romberg_array_py(y: PyReadonlyArray1<f64>, dx: f64) -> PyResult<f64> {
173    let y_arr = y.as_array();
174    let n = y_arr.len();
175
176    if n < 3 {
177        return Err(pyo3::exceptions::PyValueError::new_err(
178            "Need at least 3 points",
179        ));
180    }
181
182    // Simple implementation using available data points
183    // This is essentially Simpson's rule as a good approximation
184    let mut total = 0.0;
185    let mut i = 0;
186    while i + 2 < n {
187        total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
188        i += 2;
189    }
190    if i + 1 < n {
191        total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
192    }
193
194    Ok(total)
195}
196
197// =============================================================================
198// Adaptive Quadrature
199// =============================================================================
200
201/// Adaptive quadrature integration
202///
203/// Parameters:
204/// - fun: Function to integrate
205/// - a: Lower bound
206/// - b: Upper bound
207/// - epsabs: Absolute error tolerance (default 1.49e-8)
208/// - epsrel: Relative error tolerance (default 1.49e-8)
209/// - maxiter: Maximum function evaluations (default 500)
210///
211/// Returns:
212/// - Dict with 'value' (integral), 'error' (estimated error), 'neval', 'success'
213#[pyfunction]
214#[pyo3(signature = (fun, a, b, epsabs=1.49e-8, epsrel=1.49e-8, maxiter=500))]
215fn quad_py(
216    py: Python,
217    fun: &Bound<'_, PyAny>,
218    a: f64,
219    b: f64,
220    epsabs: f64,
221    epsrel: f64,
222    maxiter: usize,
223) -> PyResult<Py<PyAny>> {
224    let fun_clone = fun.clone().unbind();
225    let f = |x: f64| -> f64 {
226        Python::attach(|py| {
227            let result = fun_clone
228                .bind(py)
229                .call1((x,))
230                .expect("Failed to call function");
231            result.extract().expect("Failed to extract result")
232        })
233    };
234
235    let options = QuadOptions {
236        abs_tol: epsabs,
237        rel_tol: epsrel,
238        max_evals: maxiter,
239        ..Default::default()
240    };
241
242    let result = quad(f, a, b, Some(options))
243        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245    let dict = PyDict::new(py);
246    dict.set_item("value", result.value)?;
247    dict.set_item("error", result.abs_error)?;
248    dict.set_item("neval", result.n_evals)?;
249    dict.set_item("success", result.converged)?;
250
251    Ok(dict.into())
252}
253
254// =============================================================================
255// ODE Solvers
256// =============================================================================
257
258/// Solve an initial value problem for a system of ODEs
259///
260/// Parameters:
261/// - fun: Function computing dy/dt = f(t, y)
262/// - t_span: Tuple (t0, tf) for integration interval
263/// - y0: Initial state
264/// - method: 'RK45' (default), 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA'
265/// - rtol: Relative tolerance (default 1e-3)
266/// - atol: Absolute tolerance (default 1e-6)
267/// - max_step: Maximum step size (optional)
268///
269/// Returns:
270/// - Dict with 't' (times), 'y' (solutions), 'nfev', 'success', 'message'
271#[pyfunction]
272#[pyo3(signature = (fun, t_span, y0, method="RK45", rtol=1e-3, atol=1e-6, max_step=None))]
273fn solve_ivp_py(
274    py: Python,
275    fun: &Bound<'_, PyAny>,
276    t_span: (f64, f64),
277    y0: Vec<f64>,
278    method: &str,
279    rtol: f64,
280    atol: f64,
281    max_step: Option<f64>,
282) -> PyResult<Py<PyAny>> {
283    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
284    let f = move |t: f64, y: ArrayView1<f64>| -> Array1_17<f64> {
285        let fun_clone = fun_arc.clone();
286        Python::attach(|py| {
287            let y_vec: Vec<f64> = y.to_vec();
288            let result = fun_clone
289                .bind(py)
290                .call1((t, y_vec))
291                .expect("Failed to call ODE function");
292            let result_vec: Vec<f64> = result.extract().expect("Failed to extract result");
293            Array1_17::from_vec(result_vec)
294        })
295    };
296
297    let ode_method = match method.to_uppercase().as_str() {
298        "EULER" => ODEMethod::Euler,
299        "RK4" => ODEMethod::RK4,
300        "RK23" => ODEMethod::RK23,
301        "RK45" => ODEMethod::RK45,
302        "DOP853" => ODEMethod::DOP853,
303        "BDF" => ODEMethod::Bdf,
304        "RADAU" => ODEMethod::Radau,
305        "LSODA" => ODEMethod::LSODA,
306        _ => ODEMethod::RK45,
307    };
308
309    let options = ODEOptions {
310        method: ode_method,
311        rtol,
312        atol,
313        max_step,
314        ..Default::default()
315    };
316
317    let y0_arr = Array1_17::from_vec(y0);
318    let result = solve_ivp(f, [t_span.0, t_span.1], y0_arr, Some(options))
319        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
320
321    // Convert results to Python
322    let t_vec: Vec<f64> = result.t.to_vec();
323
324    // Convert y (Vec<Array1>) to 2D array
325    let n_points = result.y.len();
326    let n_dim = if n_points > 0 { result.y[0].len() } else { 0 };
327    let mut y_flat = Vec::with_capacity(n_points * n_dim);
328    for arr in &result.y {
329        for &val in arr.iter() {
330            y_flat.push(val);
331        }
332    }
333
334    let dict = PyDict::new(py);
335    dict.set_item("t", scirs_to_numpy_array1(Array1::from_vec(t_vec), py)?)?;
336
337    // Create 2D array for y
338    let y_arr = scirs2_core::python::numpy_compat::Array2::from_shape_vec((n_dim, n_points), {
339        let mut transposed = Vec::with_capacity(n_points * n_dim);
340        for j in 0..n_dim {
341            for i in 0..n_points {
342                transposed.push(y_flat[i * n_dim + j]);
343            }
344        }
345        transposed
346    })
347    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
348    dict.set_item(
349        "y",
350        scirs2_core::python::numpy_compat::scirs_to_numpy_array2(y_arr, py)?,
351    )?;
352
353    dict.set_item("nfev", result.n_eval)?;
354    dict.set_item("success", result.success)?;
355    dict.set_item("message", result.message)?;
356
357    Ok(dict.into())
358}
359
360/// Python module registration
361pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
362    m.add_function(wrap_pyfunction!(trapezoid_array_py, m)?)?;
363    m.add_function(wrap_pyfunction!(simpson_array_py, m)?)?;
364    m.add_function(wrap_pyfunction!(cumulative_trapezoid_py, m)?)?;
365    m.add_function(wrap_pyfunction!(romberg_array_py, m)?)?;
366    m.add_function(wrap_pyfunction!(quad_py, m)?)?;
367    m.add_function(wrap_pyfunction!(solve_ivp_py, m)?)?;
368
369    Ok(())
370}