1use pyo3::prelude::*;
9use pyo3::types::PyDict;
10use scirs2_core::ndarray::{Array1 as Array1_17, ArrayView1};
11use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
12use scirs2_numpy::{PyArray1, PyReadonlyArray1};
13
14use scirs2_integrate::ode::{solve_ivp, ODEMethod, ODEOptions};
15use scirs2_integrate::quad::{quad, QuadOptions};
16
17#[pyfunction]
25#[pyo3(signature = (y, x=None, dx=1.0))]
26fn trapezoid_array_py(
27 y: PyReadonlyArray1<f64>,
28 x: Option<PyReadonlyArray1<f64>>,
29 dx: f64,
30) -> PyResult<f64> {
31 let y_arr = y.as_array();
32
33 if y_arr.len() < 2 {
34 return Err(pyo3::exceptions::PyValueError::new_err(
35 "Need at least 2 points",
36 ));
37 }
38
39 let result = if let Some(x_py) = x {
40 let x_arr = x_py.as_array();
41 if x_arr.len() != y_arr.len() {
42 return Err(pyo3::exceptions::PyValueError::new_err(
43 "x and y must have same length",
44 ));
45 }
46 let mut total = 0.0;
48 for i in 0..y_arr.len() - 1 {
49 let dx = x_arr[i + 1] - x_arr[i];
50 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
51 }
52 total
53 } else {
54 let mut total = 0.0;
56 for i in 0..y_arr.len() - 1 {
57 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
58 }
59 total
60 };
61
62 Ok(result)
63}
64
65#[pyfunction]
69#[pyo3(signature = (y, x=None, dx=1.0))]
70fn simpson_array_py(
71 y: PyReadonlyArray1<f64>,
72 x: Option<PyReadonlyArray1<f64>>,
73 dx: f64,
74) -> PyResult<f64> {
75 let y_arr = y.as_array();
76 let n = y_arr.len();
77
78 if n < 3 {
79 return Err(pyo3::exceptions::PyValueError::new_err(
80 "Need at least 3 points",
81 ));
82 }
83
84 let result = if let Some(x_py) = x {
86 let x_arr = x_py.as_array();
87 if x_arr.len() != y_arr.len() {
88 return Err(pyo3::exceptions::PyValueError::new_err(
89 "x and y must have same length",
90 ));
91 }
92
93 let mut total = 0.0;
94 let mut i = 0;
95 while i + 2 < n {
96 let h = (x_arr[i + 2] - x_arr[i]) / 2.0;
97 total += h / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
98 i += 2;
99 }
100 if i + 1 < n {
102 let h = x_arr[i + 1] - x_arr[i];
103 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * h;
104 }
105 total
106 } else {
107 let mut total = 0.0;
108 let mut i = 0;
109 while i + 2 < n {
110 total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
111 i += 2;
112 }
113 if i + 1 < n {
115 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
116 }
117 total
118 };
119
120 Ok(result)
121}
122
123#[pyfunction]
127#[pyo3(signature = (y, x=None, dx=1.0, initial=None))]
128fn cumulative_trapezoid_py(
129 py: Python,
130 y: PyReadonlyArray1<f64>,
131 x: Option<PyReadonlyArray1<f64>>,
132 dx: f64,
133 initial: Option<f64>,
134) -> PyResult<Py<PyArray1<f64>>> {
135 let y_arr = y.as_array();
136
137 if y_arr.len() < 2 {
138 return Err(pyo3::exceptions::PyValueError::new_err(
139 "Need at least 2 points",
140 ));
141 }
142
143 let n = y_arr.len();
144 let has_initial = initial.is_some();
145 let mut result = Vec::with_capacity(if has_initial { n } else { n - 1 });
146
147 if let Some(init) = initial {
148 result.push(init);
149 }
150
151 let mut cumsum = initial.unwrap_or(0.0);
152
153 if let Some(x_py) = x {
154 let x_arr = x_py.as_array();
155 for i in 0..y_arr.len() - 1 {
156 let dx_i = x_arr[i + 1] - x_arr[i];
157 cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx_i;
158 result.push(cumsum);
159 }
160 } else {
161 for i in 0..y_arr.len() - 1 {
162 cumsum += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
163 result.push(cumsum);
164 }
165 }
166
167 scirs_to_numpy_array1(Array1::from_vec(result), py)
168}
169
170#[pyfunction]
172fn romberg_array_py(y: PyReadonlyArray1<f64>, dx: f64) -> PyResult<f64> {
173 let y_arr = y.as_array();
174 let n = y_arr.len();
175
176 if n < 3 {
177 return Err(pyo3::exceptions::PyValueError::new_err(
178 "Need at least 3 points",
179 ));
180 }
181
182 let mut total = 0.0;
185 let mut i = 0;
186 while i + 2 < n {
187 total += dx / 3.0 * (y_arr[i] + 4.0 * y_arr[i + 1] + y_arr[i + 2]);
188 i += 2;
189 }
190 if i + 1 < n {
191 total += 0.5 * (y_arr[i] + y_arr[i + 1]) * dx;
192 }
193
194 Ok(total)
195}
196
197#[pyfunction]
214#[pyo3(signature = (fun, a, b, epsabs=1.49e-8, epsrel=1.49e-8, maxiter=500))]
215fn quad_py(
216 py: Python,
217 fun: &Bound<'_, PyAny>,
218 a: f64,
219 b: f64,
220 epsabs: f64,
221 epsrel: f64,
222 maxiter: usize,
223) -> PyResult<Py<PyAny>> {
224 let fun_clone = fun.clone().unbind();
225 let f = |x: f64| -> f64 {
226 Python::attach(|py| {
227 let result = fun_clone
228 .bind(py)
229 .call1((x,))
230 .expect("Failed to call function");
231 result.extract().expect("Failed to extract result")
232 })
233 };
234
235 let options = QuadOptions {
236 abs_tol: epsabs,
237 rel_tol: epsrel,
238 max_evals: maxiter,
239 ..Default::default()
240 };
241
242 let result = quad(f, a, b, Some(options))
243 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
244
245 let dict = PyDict::new(py);
246 dict.set_item("value", result.value)?;
247 dict.set_item("error", result.abs_error)?;
248 dict.set_item("neval", result.n_evals)?;
249 dict.set_item("success", result.converged)?;
250
251 Ok(dict.into())
252}
253
254#[pyfunction]
272#[pyo3(signature = (fun, t_span, y0, method="RK45", rtol=1e-3, atol=1e-6, max_step=None))]
273fn solve_ivp_py(
274 py: Python,
275 fun: &Bound<'_, PyAny>,
276 t_span: (f64, f64),
277 y0: Vec<f64>,
278 method: &str,
279 rtol: f64,
280 atol: f64,
281 max_step: Option<f64>,
282) -> PyResult<Py<PyAny>> {
283 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
284 let f = move |t: f64, y: ArrayView1<f64>| -> Array1_17<f64> {
285 let fun_clone = fun_arc.clone();
286 Python::attach(|py| {
287 let y_vec: Vec<f64> = y.to_vec();
288 let result = fun_clone
289 .bind(py)
290 .call1((t, y_vec))
291 .expect("Failed to call ODE function");
292 let result_vec: Vec<f64> = result.extract().expect("Failed to extract result");
293 Array1_17::from_vec(result_vec)
294 })
295 };
296
297 let ode_method = match method.to_uppercase().as_str() {
298 "EULER" => ODEMethod::Euler,
299 "RK4" => ODEMethod::RK4,
300 "RK23" => ODEMethod::RK23,
301 "RK45" => ODEMethod::RK45,
302 "DOP853" => ODEMethod::DOP853,
303 "BDF" => ODEMethod::Bdf,
304 "RADAU" => ODEMethod::Radau,
305 "LSODA" => ODEMethod::LSODA,
306 _ => ODEMethod::RK45,
307 };
308
309 let options = ODEOptions {
310 method: ode_method,
311 rtol,
312 atol,
313 max_step,
314 ..Default::default()
315 };
316
317 let y0_arr = Array1_17::from_vec(y0);
318 let result = solve_ivp(f, [t_span.0, t_span.1], y0_arr, Some(options))
319 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
320
321 let t_vec: Vec<f64> = result.t.to_vec();
323
324 let n_points = result.y.len();
326 let n_dim = if n_points > 0 { result.y[0].len() } else { 0 };
327 let mut y_flat = Vec::with_capacity(n_points * n_dim);
328 for arr in &result.y {
329 for &val in arr.iter() {
330 y_flat.push(val);
331 }
332 }
333
334 let dict = PyDict::new(py);
335 dict.set_item("t", scirs_to_numpy_array1(Array1::from_vec(t_vec), py)?)?;
336
337 let y_arr = scirs2_core::python::numpy_compat::Array2::from_shape_vec((n_dim, n_points), {
339 let mut transposed = Vec::with_capacity(n_points * n_dim);
340 for j in 0..n_dim {
341 for i in 0..n_points {
342 transposed.push(y_flat[i * n_dim + j]);
343 }
344 }
345 transposed
346 })
347 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
348 dict.set_item(
349 "y",
350 scirs2_core::python::numpy_compat::scirs_to_numpy_array2(y_arr, py)?,
351 )?;
352
353 dict.set_item("nfev", result.n_eval)?;
354 dict.set_item("success", result.success)?;
355 dict.set_item("message", result.message)?;
356
357 Ok(dict.into())
358}
359
360pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
362 m.add_function(wrap_pyfunction!(trapezoid_array_py, m)?)?;
363 m.add_function(wrap_pyfunction!(simpson_array_py, m)?)?;
364 m.add_function(wrap_pyfunction!(cumulative_trapezoid_py, m)?)?;
365 m.add_function(wrap_pyfunction!(romberg_array_py, m)?)?;
366 m.add_function(wrap_pyfunction!(quad_py, m)?)?;
367 m.add_function(wrap_pyfunction!(solve_ivp_py, m)?)?;
368
369 Ok(())
370}