1use pyo3::prelude::*;
6use scirs2_core::ndarray::{Array1 as Array1_17, Array2 as Array2_17};
7use scirs2_core::python::numpy_compat::{
8 scirs_to_numpy_array1, scirs_to_numpy_array2, Array1, Array2,
9};
10use scirs2_numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
11
12use scirs2_interpolate::interp1d::{ExtrapolateMode, Interp1d, InterpolationMethod};
13use scirs2_interpolate::interp2d::{Interp2d, Interp2dKind};
14use scirs2_interpolate::spline::CubicSpline;
15use scirs2_interpolate::{RBFInterpolator, RBFKernel};
16
17#[pyclass(name = "Interp1d")]
19pub struct PyInterp1d {
20 interp: Interp1d<f64>,
21}
22
23#[pymethods]
24impl PyInterp1d {
25 #[new]
33 #[pyo3(signature = (x, y, method="linear", extrapolate="error"))]
34 fn new(
35 x: PyReadonlyArray1<f64>,
36 y: PyReadonlyArray1<f64>,
37 method: &str,
38 extrapolate: &str,
39 ) -> PyResult<Self> {
40 let x_arr = x.as_array().to_owned();
42 let y_arr = y.as_array().to_owned();
43
44 let method = match method.to_lowercase().as_str() {
45 "nearest" => InterpolationMethod::Nearest,
46 "linear" => InterpolationMethod::Linear,
47 "cubic" => InterpolationMethod::Cubic,
48 "pchip" => InterpolationMethod::Pchip,
49 _ => InterpolationMethod::Linear,
50 };
51
52 let extrapolate_mode = match extrapolate.to_lowercase().as_str() {
53 "nearest" => ExtrapolateMode::Nearest,
54 "extrapolate" => ExtrapolateMode::Extrapolate,
55 _ => ExtrapolateMode::Error,
56 };
57
58 let interp = Interp1d::new(&x_arr.view(), &y_arr.view(), method, extrapolate_mode)
59 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
60
61 Ok(PyInterp1d { interp })
62 }
63
64 fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
66 let x_vec: Vec<f64> = x_new.as_array().to_vec();
67 let x_arr = Array1_17::from_vec(x_vec);
68 let result = self
69 .interp
70 .evaluate_array(&x_arr.view())
71 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
72 scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
74 }
75
76 fn eval_single(&self, x: f64) -> PyResult<f64> {
78 self.interp
79 .evaluate(x)
80 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
81 }
82}
83
84#[pyclass(name = "CubicSpline")]
86pub struct PyCubicSpline {
87 spline: CubicSpline<f64>,
88}
89
90#[pymethods]
91impl PyCubicSpline {
92 #[new]
99 #[pyo3(signature = (x, y, bc_type="natural"))]
100 fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, bc_type: &str) -> PyResult<Self> {
101 let x_vec: Vec<f64> = x.as_array().to_vec();
102 let y_vec: Vec<f64> = y.as_array().to_vec();
103 let x_arr = Array1_17::from_vec(x_vec);
104 let y_arr = Array1_17::from_vec(y_vec);
105
106 let spline = match bc_type.to_lowercase().as_str() {
107 "natural" | "not-a-knot" | "periodic" => CubicSpline::new(&x_arr.view(), &y_arr.view())
108 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?,
109 _ => {
110 return Err(pyo3::exceptions::PyValueError::new_err(format!(
111 "Unsupported boundary condition: {}",
112 bc_type
113 )));
114 }
115 };
116
117 Ok(PyCubicSpline { spline })
118 }
119
120 fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
122 let x_vec: Vec<f64> = x_new.as_array().to_vec();
123 let mut result = Vec::with_capacity(x_vec.len());
124
125 for &x in &x_vec {
126 let y = self
127 .spline
128 .evaluate(x)
129 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
130 result.push(y);
131 }
132
133 scirs_to_numpy_array1(Array1::from_vec(result), py)
134 }
135
136 fn eval_single(&self, x: f64) -> PyResult<f64> {
138 self.spline
139 .evaluate(x)
140 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
141 }
142
143 #[pyo3(signature = (x_new, nu=1))]
149 fn derivative(
150 &self,
151 py: Python,
152 x_new: PyReadonlyArray1<f64>,
153 nu: usize,
154 ) -> PyResult<Py<PyArray1<f64>>> {
155 let x_vec: Vec<f64> = x_new.as_array().to_vec();
156 let mut result = Vec::with_capacity(x_vec.len());
157
158 for &x in &x_vec {
159 let y = self
160 .spline
161 .derivative_n(x, nu)
162 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
163 result.push(y);
164 }
165
166 scirs_to_numpy_array1(Array1::from_vec(result), py)
167 }
168
169 fn integrate(&self, a: f64, b: f64) -> PyResult<f64> {
175 self.spline
176 .integrate(a, b)
177 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
178 }
179}
180
181#[pyclass(name = "Interp2d")]
183pub struct PyInterp2d {
184 interp: Interp2d<f64>,
185}
186
187#[pymethods]
188impl PyInterp2d {
189 #[new]
197 #[pyo3(signature = (x, y, z, kind="linear"))]
198 fn new(
199 x: PyReadonlyArray1<f64>,
200 y: PyReadonlyArray1<f64>,
201 z: PyReadonlyArray2<f64>,
202 kind: &str,
203 ) -> PyResult<Self> {
204 let x_vec: Vec<f64> = x.as_array().to_vec();
205 let y_vec: Vec<f64> = y.as_array().to_vec();
206 let z_arr = z.as_array();
207
208 let x_arr = Array1_17::from_vec(x_vec);
209 let y_arr = Array1_17::from_vec(y_vec);
210
211 let z_shape = z_arr.shape();
213 let z_vec: Vec<f64> = z_arr.iter().copied().collect();
214 let z_arr_17 = Array2_17::from_shape_vec((z_shape[0], z_shape[1]), z_vec).map_err(|e| {
215 pyo3::exceptions::PyValueError::new_err(format!("Invalid z array: {e}"))
216 })?;
217
218 let interp_kind = match kind.to_lowercase().as_str() {
219 "linear" => Interp2dKind::Linear,
220 "cubic" => Interp2dKind::Cubic,
221 "quintic" => Interp2dKind::Quintic,
222 _ => Interp2dKind::Linear,
223 };
224
225 let interp = Interp2d::new(&x_arr.view(), &y_arr.view(), &z_arr_17.view(), interp_kind)
226 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
227
228 Ok(PyInterp2d { interp })
229 }
230
231 fn __call__(&self, x: f64, y: f64) -> PyResult<f64> {
233 self.interp
234 .evaluate(x, y)
235 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
236 }
237
238 fn eval_array(
244 &self,
245 py: Python,
246 x_new: PyReadonlyArray1<f64>,
247 y_new: PyReadonlyArray1<f64>,
248 ) -> PyResult<Py<PyArray1<f64>>> {
249 let x_vec: Vec<f64> = x_new.as_array().to_vec();
250 let y_vec: Vec<f64> = y_new.as_array().to_vec();
251 let x_arr = Array1_17::from_vec(x_vec);
252 let y_arr = Array1_17::from_vec(y_vec);
253
254 let result = self
255 .interp
256 .evaluate_array(&x_arr.view(), &y_arr.view())
257 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
258
259 scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
260 }
261
262 fn eval_grid(
271 &self,
272 py: Python,
273 x_new: PyReadonlyArray1<f64>,
274 y_new: PyReadonlyArray1<f64>,
275 ) -> PyResult<Py<PyArray2<f64>>> {
276 let x_vec: Vec<f64> = x_new.as_array().to_vec();
277 let y_vec: Vec<f64> = y_new.as_array().to_vec();
278 let x_arr = Array1_17::from_vec(x_vec);
279 let y_arr = Array1_17::from_vec(y_vec);
280
281 let result = self
282 .interp
283 .evaluate_grid(&x_arr.view(), &y_arr.view())
284 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
285
286 let shape = result.dim();
288 let vec: Vec<f64> = result.into_iter().collect();
289 let arr2 = Array2::from_shape_vec(shape, vec).map_err(|e| {
290 pyo3::exceptions::PyRuntimeError::new_err(format!("Array reshape failed: {e}"))
291 })?;
292 scirs_to_numpy_array2(arr2, py)
293 }
294}
295
296#[pyclass(name = "RBFInterpolator")]
302pub struct PyRBFInterpolator {
303 interp: RBFInterpolator<f64>,
304}
305
306#[pymethods]
307impl PyRBFInterpolator {
308 #[new]
317 #[pyo3(signature = (points, values, kernel="gaussian", epsilon=1.0))]
318 fn new(
319 points: PyReadonlyArray2<f64>,
320 values: PyReadonlyArray1<f64>,
321 kernel: &str,
322 epsilon: f64,
323 ) -> PyResult<Self> {
324 let pts = points.as_array();
325 let vals = values.as_array();
326
327 let shape = pts.dim();
329 let pts_vec: Vec<f64> = pts.iter().copied().collect();
330 let pts_17 = Array2_17::from_shape_vec(shape, pts_vec).map_err(|e| {
331 pyo3::exceptions::PyValueError::new_err(format!("Invalid points array: {e}"))
332 })?;
333 let vals_vec: Vec<f64> = vals.iter().copied().collect();
334 let vals_17 = scirs2_core::ndarray::Array1::from_vec(vals_vec);
335
336 let rbf_kernel = match kernel.to_lowercase().as_str() {
337 "gaussian" => RBFKernel::Gaussian,
338 "multiquadric" => RBFKernel::Multiquadric,
339 "inverse_multiquadric" | "inverse-multiquadric" => RBFKernel::InverseMultiquadric,
340 "thin_plate_spline" | "thin-plate-spline" => RBFKernel::ThinPlateSpline,
341 "linear" => RBFKernel::Linear,
342 "cubic" => RBFKernel::Cubic,
343 _ => RBFKernel::Gaussian,
344 };
345
346 let interp = RBFInterpolator::new(&pts_17.view(), &vals_17.view(), rbf_kernel, epsilon)
347 .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
348
349 Ok(PyRBFInterpolator { interp })
350 }
351
352 fn __call__(
360 &self,
361 py: Python,
362 query_points: PyReadonlyArray2<f64>,
363 ) -> PyResult<Py<PyArray1<f64>>> {
364 let pts = query_points.as_array();
365 let shape = pts.dim();
366 let pts_vec: Vec<f64> = pts.iter().copied().collect();
367 let pts_17 = Array2_17::from_shape_vec(shape, pts_vec).map_err(|e| {
368 pyo3::exceptions::PyValueError::new_err(format!("Invalid query array: {e}"))
369 })?;
370
371 let result = self
372 .interp
373 .interpolate(&pts_17.view())
374 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
375
376 scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
377 }
378}
379
380#[pyfunction]
386fn interp_py(
387 py: Python,
388 x: PyReadonlyArray1<f64>,
389 xp: PyReadonlyArray1<f64>,
390 fp: PyReadonlyArray1<f64>,
391) -> PyResult<Py<PyArray1<f64>>> {
392 let x_arr = x.as_array();
393 let xp_arr = xp.as_array();
394 let fp_arr = fp.as_array();
395
396 let n = xp_arr.len();
397 if n == 0 {
398 return Err(pyo3::exceptions::PyValueError::new_err(
399 "xp must not be empty",
400 ));
401 }
402 if n != fp_arr.len() {
403 return Err(pyo3::exceptions::PyValueError::new_err(
404 "xp and fp must have same length",
405 ));
406 }
407
408 let xp_slice = xp_arr
409 .as_slice()
410 .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("xp array is not contiguous"))?;
411 let fp_slice = fp_arr
412 .as_slice()
413 .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("fp array is not contiguous"))?;
414
415 let mut result = Vec::with_capacity(x_arr.len());
417
418 for &xi in x_arr.iter() {
419 let yi = if xi <= xp_slice[0] {
420 fp_slice[0]
421 } else if xi >= xp_slice[n - 1] {
422 fp_slice[n - 1]
423 } else {
424 let idx = xp_slice.partition_point(|&v| v < xi);
426 let i = if idx > 0 { idx - 1 } else { 0 };
427
428 let x0 = xp_slice[i];
430 let x1 = xp_slice[i + 1];
431 let y0 = fp_slice[i];
432 let y1 = fp_slice[i + 1];
433 let t = (xi - x0) / (x1 - x0);
434 y0 + t * (y1 - y0)
435 };
436 result.push(yi);
437 }
438
439 scirs_to_numpy_array1(Array1::from_vec(result), py)
440}
441
442#[pyfunction]
444#[pyo3(signature = (x, xp, fp, left=None, right=None))]
445fn interp_with_bounds_py(
446 py: Python,
447 x: PyReadonlyArray1<f64>,
448 xp: PyReadonlyArray1<f64>,
449 fp: PyReadonlyArray1<f64>,
450 left: Option<f64>,
451 right: Option<f64>,
452) -> PyResult<Py<PyArray1<f64>>> {
453 let x_arr = x.as_array();
454 let xp_arr = xp.as_array();
455 let fp_arr = fp.as_array();
456
457 let n = xp_arr.len();
458 if n == 0 || fp_arr.len() != n {
459 return Err(pyo3::exceptions::PyValueError::new_err(
460 "Invalid input arrays",
461 ));
462 }
463
464 let mut result = Vec::with_capacity(x_arr.len());
465
466 for &xi in x_arr.iter() {
467 let yi = if xi < xp_arr[0] {
468 left.unwrap_or(fp_arr[0])
469 } else if xi > xp_arr[n - 1] {
470 right.unwrap_or(fp_arr[n - 1])
471 } else {
472 let mut lo = 0;
474 let mut hi = n - 1;
475 while hi - lo > 1 {
476 let mid = (lo + hi) / 2;
477 if xp_arr[mid] <= xi {
478 lo = mid;
479 } else {
480 hi = mid;
481 }
482 }
483 let t = (xi - xp_arr[lo]) / (xp_arr[hi] - xp_arr[lo]);
485 fp_arr[lo] * (1.0 - t) + fp_arr[hi] * t
486 };
487 result.push(yi);
488 }
489
490 scirs_to_numpy_array1(Array1::from_vec(result), py)
491}
492
493pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
495 m.add_class::<PyInterp1d>()?;
496 m.add_class::<PyCubicSpline>()?;
497 m.add_class::<PyInterp2d>()?;
498 m.add_class::<PyRBFInterpolator>()?;
499 m.add_function(wrap_pyfunction!(interp_py, m)?)?;
500 m.add_function(wrap_pyfunction!(interp_with_bounds_py, m)?)?;
501
502 Ok(())
503}