1use pyo3::prelude::*;
6use scirs2_core::ndarray::{Array1 as Array1_17, Array2 as Array2_17};
7use scirs2_core::python::numpy_compat::{
8 scirs_to_numpy_array1, scirs_to_numpy_array2, Array1, Array2,
9};
10use scirs2_numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
11
12use scirs2_interpolate::interp1d::{ExtrapolateMode, Interp1d, InterpolationMethod};
13use scirs2_interpolate::interp2d::{Interp2d, Interp2dKind};
14use scirs2_interpolate::spline::CubicSpline;
15
/// Python-visible wrapper around the Rust `Interp1d<f64>` interpolator.
///
/// Registered with Python under the class name `Interp1d`.
#[pyclass(name = "Interp1d")]
pub struct PyInterp1d {
    /// The wrapped interpolator; the Python-facing methods delegate to it.
    interp: Interp1d<f64>,
}
21
22#[pymethods]
23impl PyInterp1d {
24 #[new]
32 #[pyo3(signature = (x, y, method="linear", extrapolate="error"))]
33 fn new(
34 x: PyReadonlyArray1<f64>,
35 y: PyReadonlyArray1<f64>,
36 method: &str,
37 extrapolate: &str,
38 ) -> PyResult<Self> {
39 let x_arr = x.as_array().to_owned();
41 let y_arr = y.as_array().to_owned();
42
43 let method = match method.to_lowercase().as_str() {
44 "nearest" => InterpolationMethod::Nearest,
45 "linear" => InterpolationMethod::Linear,
46 "cubic" => InterpolationMethod::Cubic,
47 "pchip" => InterpolationMethod::Pchip,
48 _ => InterpolationMethod::Linear,
49 };
50
51 let extrapolate_mode = match extrapolate.to_lowercase().as_str() {
52 "nearest" => ExtrapolateMode::Nearest,
53 "extrapolate" => ExtrapolateMode::Extrapolate,
54 _ => ExtrapolateMode::Error,
55 };
56
57 let interp = Interp1d::new(&x_arr.view(), &y_arr.view(), method, extrapolate_mode)
58 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
59
60 Ok(PyInterp1d { interp })
61 }
62
63 fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
65 let x_vec: Vec<f64> = x_new.as_array().to_vec();
66 let x_arr = Array1_17::from_vec(x_vec);
67 let result = self
68 .interp
69 .evaluate_array(&x_arr.view())
70 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
71 scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
73 }
74
75 fn eval_single(&self, x: f64) -> PyResult<f64> {
77 self.interp
78 .evaluate(x)
79 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
80 }
81}
82
/// Python-visible wrapper around the Rust `CubicSpline<f64>`.
///
/// Registered with Python under the class name `CubicSpline`.
#[pyclass(name = "CubicSpline")]
pub struct PyCubicSpline {
    /// The wrapped spline; the Python-facing methods delegate to it.
    spline: CubicSpline<f64>,
}
88
89#[pymethods]
90impl PyCubicSpline {
91 #[new]
98 #[pyo3(signature = (x, y, bc_type="natural"))]
99 fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, bc_type: &str) -> PyResult<Self> {
100 let x_vec: Vec<f64> = x.as_array().to_vec();
101 let y_vec: Vec<f64> = y.as_array().to_vec();
102 let x_arr = Array1_17::from_vec(x_vec);
103 let y_arr = Array1_17::from_vec(y_vec);
104
105 let spline = match bc_type.to_lowercase().as_str() {
106 "natural" | "not-a-knot" | "periodic" => CubicSpline::new(&x_arr.view(), &y_arr.view())
107 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?,
108 _ => {
109 return Err(pyo3::exceptions::PyValueError::new_err(format!(
110 "Unsupported boundary condition: {}",
111 bc_type
112 )));
113 }
114 };
115
116 Ok(PyCubicSpline { spline })
117 }
118
119 fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
121 let x_vec: Vec<f64> = x_new.as_array().to_vec();
122 let mut result = Vec::with_capacity(x_vec.len());
123
124 for &x in &x_vec {
125 let y = self
126 .spline
127 .evaluate(x)
128 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
129 result.push(y);
130 }
131
132 scirs_to_numpy_array1(Array1::from_vec(result), py)
133 }
134
135 fn eval_single(&self, x: f64) -> PyResult<f64> {
137 self.spline
138 .evaluate(x)
139 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
140 }
141
142 #[pyo3(signature = (x_new, nu=1))]
148 fn derivative(
149 &self,
150 py: Python,
151 x_new: PyReadonlyArray1<f64>,
152 nu: usize,
153 ) -> PyResult<Py<PyArray1<f64>>> {
154 let x_vec: Vec<f64> = x_new.as_array().to_vec();
155 let mut result = Vec::with_capacity(x_vec.len());
156
157 for &x in &x_vec {
158 let y = self
159 .spline
160 .derivative_n(x, nu)
161 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
162 result.push(y);
163 }
164
165 scirs_to_numpy_array1(Array1::from_vec(result), py)
166 }
167
168 fn integrate(&self, a: f64, b: f64) -> PyResult<f64> {
174 self.spline
175 .integrate(a, b)
176 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
177 }
178}
179
/// Python-visible wrapper around the Rust `Interp2d<f64>` interpolator.
///
/// Registered with Python under the class name `Interp2d`.
#[pyclass(name = "Interp2d")]
pub struct PyInterp2d {
    /// The wrapped interpolator; the Python-facing methods delegate to it.
    interp: Interp2d<f64>,
}
185
186#[pymethods]
187impl PyInterp2d {
188 #[new]
196 #[pyo3(signature = (x, y, z, kind="linear"))]
197 fn new(
198 x: PyReadonlyArray1<f64>,
199 y: PyReadonlyArray1<f64>,
200 z: PyReadonlyArray2<f64>,
201 kind: &str,
202 ) -> PyResult<Self> {
203 let x_vec: Vec<f64> = x.as_array().to_vec();
204 let y_vec: Vec<f64> = y.as_array().to_vec();
205 let z_arr = z.as_array();
206
207 let x_arr = Array1_17::from_vec(x_vec);
208 let y_arr = Array1_17::from_vec(y_vec);
209
210 let z_shape = z_arr.shape();
212 let z_vec: Vec<f64> = z_arr.iter().copied().collect();
213 let z_arr_17 = Array2_17::from_shape_vec((z_shape[0], z_shape[1]), z_vec).map_err(|e| {
214 pyo3::exceptions::PyValueError::new_err(format!("Invalid z array: {e}"))
215 })?;
216
217 let interp_kind = match kind.to_lowercase().as_str() {
218 "linear" => Interp2dKind::Linear,
219 "cubic" => Interp2dKind::Cubic,
220 "quintic" => Interp2dKind::Quintic,
221 _ => Interp2dKind::Linear,
222 };
223
224 let interp = Interp2d::new(&x_arr.view(), &y_arr.view(), &z_arr_17.view(), interp_kind)
225 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
226
227 Ok(PyInterp2d { interp })
228 }
229
230 fn __call__(&self, x: f64, y: f64) -> PyResult<f64> {
232 self.interp
233 .evaluate(x, y)
234 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
235 }
236
237 fn eval_array(
243 &self,
244 py: Python,
245 x_new: PyReadonlyArray1<f64>,
246 y_new: PyReadonlyArray1<f64>,
247 ) -> PyResult<Py<PyArray1<f64>>> {
248 let x_vec: Vec<f64> = x_new.as_array().to_vec();
249 let y_vec: Vec<f64> = y_new.as_array().to_vec();
250 let x_arr = Array1_17::from_vec(x_vec);
251 let y_arr = Array1_17::from_vec(y_vec);
252
253 let result = self
254 .interp
255 .evaluate_array(&x_arr.view(), &y_arr.view())
256 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
257
258 scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
259 }
260
261 fn eval_grid(
270 &self,
271 py: Python,
272 x_new: PyReadonlyArray1<f64>,
273 y_new: PyReadonlyArray1<f64>,
274 ) -> PyResult<Py<PyArray2<f64>>> {
275 let x_vec: Vec<f64> = x_new.as_array().to_vec();
276 let y_vec: Vec<f64> = y_new.as_array().to_vec();
277 let x_arr = Array1_17::from_vec(x_vec);
278 let y_arr = Array1_17::from_vec(y_vec);
279
280 let result = self
281 .interp
282 .evaluate_grid(&x_arr.view(), &y_arr.view())
283 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
284
285 let shape = result.dim();
287 let vec: Vec<f64> = result.into_iter().collect();
288 scirs_to_numpy_array2(
289 Array2::from_shape_vec(shape, vec).expect("Operation failed"),
290 py,
291 )
292 }
293}
294
295#[pyfunction]
301fn interp_py(
302 py: Python,
303 x: PyReadonlyArray1<f64>,
304 xp: PyReadonlyArray1<f64>,
305 fp: PyReadonlyArray1<f64>,
306) -> PyResult<Py<PyArray1<f64>>> {
307 let x_arr = x.as_array();
308 let xp_arr = xp.as_array();
309 let fp_arr = fp.as_array();
310
311 let n = xp_arr.len();
312 if n == 0 {
313 return Err(pyo3::exceptions::PyValueError::new_err(
314 "xp must not be empty",
315 ));
316 }
317 if n != fp_arr.len() {
318 return Err(pyo3::exceptions::PyValueError::new_err(
319 "xp and fp must have same length",
320 ));
321 }
322
323 let xp_slice = xp_arr.as_slice().expect("Operation failed");
324 let fp_slice = fp_arr.as_slice().expect("Operation failed");
325
326 let mut result = Vec::with_capacity(x_arr.len());
328
329 for &xi in x_arr.iter() {
330 let yi = if xi <= xp_slice[0] {
331 fp_slice[0]
332 } else if xi >= xp_slice[n - 1] {
333 fp_slice[n - 1]
334 } else {
335 let idx = xp_slice.partition_point(|&v| v < xi);
337 let i = if idx > 0 { idx - 1 } else { 0 };
338
339 let x0 = xp_slice[i];
341 let x1 = xp_slice[i + 1];
342 let y0 = fp_slice[i];
343 let y1 = fp_slice[i + 1];
344 let t = (xi - x0) / (x1 - x0);
345 y0 + t * (y1 - y0)
346 };
347 result.push(yi);
348 }
349
350 scirs_to_numpy_array1(Array1::from_vec(result), py)
351}
352
353#[pyfunction]
355#[pyo3(signature = (x, xp, fp, left=None, right=None))]
356fn interp_with_bounds_py(
357 py: Python,
358 x: PyReadonlyArray1<f64>,
359 xp: PyReadonlyArray1<f64>,
360 fp: PyReadonlyArray1<f64>,
361 left: Option<f64>,
362 right: Option<f64>,
363) -> PyResult<Py<PyArray1<f64>>> {
364 let x_arr = x.as_array();
365 let xp_arr = xp.as_array();
366 let fp_arr = fp.as_array();
367
368 let n = xp_arr.len();
369 if n == 0 || fp_arr.len() != n {
370 return Err(pyo3::exceptions::PyValueError::new_err(
371 "Invalid input arrays",
372 ));
373 }
374
375 let mut result = Vec::with_capacity(x_arr.len());
376
377 for &xi in x_arr.iter() {
378 let yi = if xi < xp_arr[0] {
379 left.unwrap_or(fp_arr[0])
380 } else if xi > xp_arr[n - 1] {
381 right.unwrap_or(fp_arr[n - 1])
382 } else {
383 let mut lo = 0;
385 let mut hi = n - 1;
386 while hi - lo > 1 {
387 let mid = (lo + hi) / 2;
388 if xp_arr[mid] <= xi {
389 lo = mid;
390 } else {
391 hi = mid;
392 }
393 }
394 let t = (xi - xp_arr[lo]) / (xp_arr[hi] - xp_arr[lo]);
396 fp_arr[lo] * (1.0 - t) + fp_arr[hi] * t
397 };
398 result.push(yi);
399 }
400
401 scirs_to_numpy_array1(Array1::from_vec(result), py)
402}
403
404pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
406 m.add_class::<PyInterp1d>()?;
407 m.add_class::<PyCubicSpline>()?;
408 m.add_class::<PyInterp2d>()?;
409 m.add_function(wrap_pyfunction!(interp_py, m)?)?;
410 m.add_function(wrap_pyfunction!(interp_with_bounds_py, m)?)?;
411
412 Ok(())
413}