Skip to main content

scirs2/
integrate.rs

1//! Python bindings for scirs2-integrate
2//!
3//! Provides numerical integration similar to scipy.integrate
4
5// Callbacks acquire the GIL via Python::attach (pyo3 0.28+) since they are invoked
6// from Rust integration routines where the caller's Python<'_> token is not in scope.
7
8use pyo3::prelude::*;
9use pyo3::types::PyDict;
10use scirs2_core::ndarray::{Array1 as Array1_17, ArrayView1};
11use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
12use scirs2_numpy::{PyArray1, PyReadonlyArray1};
13
14use scirs2_integrate::ode::{solve_ivp, ODEMethod, ODEOptions};
15use scirs2_integrate::quad::{quad, QuadOptions};
16
17// =============================================================================
18// Array-based Integration (works without callbacks)
19// =============================================================================
20
21/// Integrate using array data (y values at x points) - trapezoidal rule
22///
23/// Similar to scipy.integrate.trapezoid
24#[pyfunction]
25#[pyo3(signature = (y, x=None, dx=1.0))]
26fn trapezoid_array_py(
27    y: PyReadonlyArray1<f64>,
28    x: Option<PyReadonlyArray1<f64>>,
29    dx: f64,
30) -> PyResult<f64> {
31    let y_arr = y.as_array();
32
33    if y_arr.len() < 2 {
34        return Err(pyo3::exceptions::PyValueError::new_err(
35            "Need at least 2 points",
36        ));
37    }
38
39    let result = if let Some(x_py) = x {
40        let x_arr = x_py.as_array();
41        if x_arr.len() != y_arr.len() {
42            return Err(pyo3::exceptions::PyValueError::new_err(
43                "x and y must have same length",
44            ));
45        }
46        // Non-uniform spacing
47        let mut total = 0.0;
48        for i in 0..y_arr.len() - 1 {
49            let dx = x_arr[i + 1] - x_arr[i];
50            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
51        }
52        total
53    } else {
54        // Uniform spacing with provided dx
55        let mut total = 0.0;
56        for i in 0..y_arr.len() - 1 {
57            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
58        }
59        total
60    };
61
62    Ok(result)
63}
64
65/// Integrate using array data - Simpson's rule
66///
67/// Similar to scipy.integrate.simpson
68#[pyfunction]
69#[pyo3(signature = (y, x=None, dx=1.0))]
70fn simpson_array_py(
71    y: PyReadonlyArray1<f64>,
72    x: Option<PyReadonlyArray1<f64>>,
73    dx: f64,
74) -> PyResult<f64> {
75    let y_arr = y.as_array();
76    let n = y_arr.len();
77
78    if n < 3 {
79        return Err(pyo3::exceptions::PyValueError::new_err(
80            "Need at least 3 points",
81        ));
82    }
83
84    // Use Simpson's rule for even number of intervals, fall back to trapezoid for odd
85    let result = if let Some(x_py) = x {
86        let x_arr = x_py.as_array();
87        if x_arr.len() != y_arr.len() {
88            return Err(pyo3::exceptions::PyValueError::new_err(
89                "x and y must have same length",
90            ));
91        }
92
93        let mut total = 0.0;
94        let mut i = 0;
95        while i + 2 < n {
96            let h = (x_arr[i + 2] - x_arr[i]) / 2.0;
97            total += h / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
98            i += 2;
99        }
100        // Handle remaining interval with trapezoid
101        if i + 1 < n {
102            let h = x_arr[i + 1] - x_arr[i];
103            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * h;
104        }
105        total
106    } else {
107        let mut total = 0.0;
108        let mut i = 0;
109        while i + 2 < n {
110            total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
111            i += 2;
112        }
113        // Handle remaining interval with trapezoid
114        if i + 1 < n {
115            total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
116        }
117        total
118    };
119
120    Ok(result)
121}
122
123/// Cumulative trapezoidal integration
124///
125/// Similar to scipy.integrate.cumulative_trapezoid
126#[pyfunction]
127#[pyo3(signature = (y, x=None, dx=1.0, initial=None))]
128fn cumulative_trapezoid_py(
129    py: Python,
130    y: PyReadonlyArray1<f64>,
131    x: Option<PyReadonlyArray1<f64>>,
132    dx: f64,
133    initial: Option<f64>,
134) -> PyResult<Py<PyArray1<f64>>> {
135    let y_arr = y.as_array();
136
137    if y_arr.len() < 2 {
138        return Err(pyo3::exceptions::PyValueError::new_err(
139            "Need at least 2 points",
140        ));
141    }
142
143    let n = y_arr.len();
144    let has_initial = initial.is_some();
145    let mut result = Vec::with_capacity(if has_initial { n } else { n - 1 });
146
147    if let Some(init) = initial {
148        result.push(init);
149    }
150
151    let mut cumsum = initial.unwrap_or(0.0);
152
153    if let Some(x_py) = x {
154        let x_arr = x_py.as_array();
155        for i in 0..y_arr.len() - 1 {
156            let dx_i = x_arr[i + 1] - x_arr[i];
157            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx_i;
158            result.push(cumsum);
159        }
160    } else {
161        for i in 0..y_arr.len() - 1 {
162            cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
163            result.push(cumsum);
164        }
165    }
166
167    scirs_to_numpy_array1(Array1::from_vec(result), py)
168}
169
170/// Romberg integration using array data
171#[pyfunction]
172fn romberg_array_py(y: PyReadonlyArray1<f64>, dx: f64) -> PyResult<f64> {
173    let y_arr = y.as_array();
174    let n = y_arr.len();
175
176    if n < 3 {
177        return Err(pyo3::exceptions::PyValueError::new_err(
178            "Need at least 3 points",
179        ));
180    }
181
182    // Simple implementation using available data points
183    // This is essentially Simpson's rule as a good approximation
184    let mut total = 0.0;
185    let mut i = 0;
186    while i + 2 < n {
187        total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
188        i += 2;
189    }
190    if i + 1 < n {
191        total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
192    }
193
194    Ok(total)
195}
196
197// =============================================================================
198// Adaptive Quadrature
199// =============================================================================
200
201/// Adaptive quadrature integration
202///
203/// Parameters:
204/// - fun: Function to integrate
205/// - a: Lower bound
206/// - b: Upper bound
207/// - epsabs: Absolute error tolerance (default 1.49e-8)
208/// - epsrel: Relative error tolerance (default 1.49e-8)
209/// - maxiter: Maximum function evaluations (default 500)
210///
211/// Returns:
212/// - Dict with 'value' (integral), 'error' (estimated error), 'neval', 'success'
213#[pyfunction]
214#[pyo3(signature = (fun, a, b, epsabs=1.49e-8, epsrel=1.49e-8, maxiter=500))]
215fn quad_py(
216    py: Python,
217    fun: &Bound<'_, PyAny>,
218    a: f64,
219    b: f64,
220    epsabs: f64,
221    epsrel: f64,
222    maxiter: usize,
223) -> PyResult<Py<PyAny>> {
224    let fun_clone = fun.clone().unbind();
225    let f = |x: f64| -> f64 {
226        Python::attach(|py| {
227            let result = fun_clone
228                .bind(py)
229                .call1((x,))
230                .unwrap_or_else(|_| py.None().into_bound(py));
231            result.extract::<f64>().unwrap_or(f64::NAN)
232        })
233    };
234
235    let options = QuadOptions {
236        abs_tol: epsabs,
237        rel_tol: epsrel,
238        max_evals: maxiter,
239        ..Default::default()
240    };
241
242    let result = quad(f, a, b, Some(options))
243        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245    let dict = PyDict::new(py);
246    dict.set_item("value", result.value)?;
247    dict.set_item("error", result.abs_error)?;
248    dict.set_item("neval", result.n_evals)?;
249    dict.set_item("success", result.converged)?;
250
251    Ok(dict.into())
252}
253
254// =============================================================================
255// ODE Solvers
256// =============================================================================
257
258/// Solve an initial value problem for a system of ODEs
259///
260/// Parameters:
261/// - fun: Function computing dy/dt = f(t, y)
262/// - t_span: Tuple (t0, tf) for integration interval
263/// - y0: Initial state
264/// - method: 'RK45' (default), 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA'
265/// - rtol: Relative tolerance (default 1e-3)
266/// - atol: Absolute tolerance (default 1e-6)
267/// - max_step: Maximum step size (optional)
268///
269/// Returns:
270/// - Dict with 't' (times), 'y' (solutions), 'nfev', 'success', 'message'
271#[pyfunction]
272#[pyo3(signature = (fun, t_span, y0, method="RK45", rtol=1e-3, atol=1e-6, max_step=None))]
273fn solve_ivp_py(
274    py: Python,
275    fun: &Bound<'_, PyAny>,
276    t_span: (f64, f64),
277    y0: Vec<f64>,
278    method: &str,
279    rtol: f64,
280    atol: f64,
281    max_step: Option<f64>,
282) -> PyResult<Py<PyAny>> {
283    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
284    let n_dim = y0.len();
285    let f = move |t: f64, y: ArrayView1<f64>| -> Array1_17<f64> {
286        let fun_clone = fun_arc.clone();
287        Python::attach(|py| {
288            let y_vec: Vec<f64> = y.to_vec();
289            match fun_clone.bind(py).call1((t, y_vec)) {
290                Ok(result) => match result.extract::<Vec<f64>>() {
291                    Ok(v) => Array1_17::from_vec(v),
292                    Err(_) => Array1_17::from_elem(n_dim, f64::NAN),
293                },
294                Err(_) => Array1_17::from_elem(n_dim, f64::NAN),
295            }
296        })
297    };
298
299    let ode_method = match method.to_uppercase().as_str() {
300        "EULER" => ODEMethod::Euler,
301        "RK4" => ODEMethod::RK4,
302        "RK23" => ODEMethod::RK23,
303        "RK45" => ODEMethod::RK45,
304        "DOP853" => ODEMethod::DOP853,
305        "BDF" => ODEMethod::Bdf,
306        "RADAU" => ODEMethod::Radau,
307        "LSODA" => ODEMethod::LSODA,
308        _ => ODEMethod::RK45,
309    };
310
311    let options = ODEOptions {
312        method: ode_method,
313        rtol,
314        atol,
315        max_step,
316        ..Default::default()
317    };
318
319    let y0_arr = Array1_17::from_vec(y0);
320    let result = solve_ivp(f, [t_span.0, t_span.1], y0_arr, Some(options))
321        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
322
323    // Convert results to Python
324    let t_vec: Vec<f64> = result.t.to_vec();
325
326    // Convert y (Vec<Array1>) to 2D array
327    let n_points = result.y.len();
328    let n_dim = if n_points > 0 { result.y[0].len() } else { 0 };
329    let mut y_flat = Vec::with_capacity(n_points * n_dim);
330    for arr in &result.y {
331        for &val in arr.iter() {
332            y_flat.push(val);
333        }
334    }
335
336    let dict = PyDict::new(py);
337    dict.set_item("t", scirs_to_numpy_array1(Array1::from_vec(t_vec), py)?)?;
338
339    // Create 2D array for y
340    let y_arr = scirs2_core::python::numpy_compat::Array2::from_shape_vec((n_dim, n_points), {
341        let mut transposed = Vec::with_capacity(n_points * n_dim);
342        for j in 0..n_dim {
343            for i in 0..n_points {
344                transposed.push(y_flat[i * n_dim + j]);
345            }
346        }
347        transposed
348    })
349    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
350    dict.set_item(
351        "y",
352        scirs2_core::python::numpy_compat::scirs_to_numpy_array2(y_arr, py)?,
353    )?;
354
355    dict.set_item("nfev", result.n_eval)?;
356    dict.set_item("success", result.success)?;
357    dict.set_item("message", result.message)?;
358
359    Ok(dict.into())
360}
361
362/// Python module registration
363pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
364    m.add_function(wrap_pyfunction!(trapezoid_array_py, m)?)?;
365    m.add_function(wrap_pyfunction!(simpson_array_py, m)?)?;
366    m.add_function(wrap_pyfunction!(cumulative_trapezoid_py, m)?)?;
367    m.add_function(wrap_pyfunction!(romberg_array_py, m)?)?;
368    m.add_function(wrap_pyfunction!(quad_py, m)?)?;
369    m.add_function(wrap_pyfunction!(solve_ivp_py, m)?)?;
370
371    Ok(())
372}