Skip to main content

scirs2/
interpolate.rs

1//! Python bindings for scirs2-interpolate
2//!
3//! Provides interpolation methods similar to scipy.interpolate
4
5use pyo3::prelude::*;
6use scirs2_core::ndarray::{Array1 as Array1_17, Array2 as Array2_17};
7use scirs2_core::python::numpy_compat::{
8    scirs_to_numpy_array1, scirs_to_numpy_array2, Array1, Array2,
9};
10use scirs2_numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
11
12use scirs2_interpolate::interp1d::{ExtrapolateMode, Interp1d, InterpolationMethod};
13use scirs2_interpolate::interp2d::{Interp2d, Interp2dKind};
14use scirs2_interpolate::spline::CubicSpline;
15use scirs2_interpolate::{RBFInterpolator, RBFKernel};
16
17/// 1D interpolation class
18#[pyclass(name = "Interp1d")]
19pub struct PyInterp1d {
20    interp: Interp1d<f64>,
21}
22
23#[pymethods]
24impl PyInterp1d {
25    /// Create a new 1D interpolator
26    ///
27    /// Parameters:
28    /// - x: x coordinates (must be sorted)
29    /// - y: y coordinates
30    /// - method: 'linear', 'nearest', 'cubic', or 'pchip'
31    /// - extrapolate: 'error', 'const', or 'extrapolate'
32    #[new]
33    #[pyo3(signature = (x, y, method="linear", extrapolate="error"))]
34    fn new(
35        x: PyReadonlyArray1<f64>,
36        y: PyReadonlyArray1<f64>,
37        method: &str,
38        extrapolate: &str,
39    ) -> PyResult<Self> {
40        // Convert from numpy arrays to scirs2-core ndarray17 (single copy)
41        let x_arr = x.as_array().to_owned();
42        let y_arr = y.as_array().to_owned();
43
44        let method = match method.to_lowercase().as_str() {
45            "nearest" => InterpolationMethod::Nearest,
46            "linear" => InterpolationMethod::Linear,
47            "cubic" => InterpolationMethod::Cubic,
48            "pchip" => InterpolationMethod::Pchip,
49            _ => InterpolationMethod::Linear,
50        };
51
52        let extrapolate_mode = match extrapolate.to_lowercase().as_str() {
53            "nearest" => ExtrapolateMode::Nearest,
54            "extrapolate" => ExtrapolateMode::Extrapolate,
55            _ => ExtrapolateMode::Error,
56        };
57
58        let interp = Interp1d::new(&x_arr.view(), &y_arr.view(), method, extrapolate_mode)
59            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
60
61        Ok(PyInterp1d { interp })
62    }
63
64    /// Evaluate the interpolator at new points
65    fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
66        let x_vec: Vec<f64> = x_new.as_array().to_vec();
67        let x_arr = Array1_17::from_vec(x_vec);
68        let result = self
69            .interp
70            .evaluate_array(&x_arr.view())
71            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
72        // Convert back to numpy-compatible array
73        scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
74    }
75
76    /// Evaluate at a single point
77    fn eval_single(&self, x: f64) -> PyResult<f64> {
78        self.interp
79            .evaluate(x)
80            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
81    }
82}
83
84/// Cubic spline interpolation class
85#[pyclass(name = "CubicSpline")]
86pub struct PyCubicSpline {
87    spline: CubicSpline<f64>,
88}
89
90#[pymethods]
91impl PyCubicSpline {
92    /// Create a new cubic spline interpolator
93    ///
94    /// Parameters:
95    /// - x: x coordinates (must be sorted, strictly increasing)
96    /// - y: y coordinates
97    /// - bc_type: boundary condition type ('natural', 'not-a-knot', 'clamped', 'periodic')
98    #[new]
99    #[pyo3(signature = (x, y, bc_type="natural"))]
100    fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, bc_type: &str) -> PyResult<Self> {
101        let x_vec: Vec<f64> = x.as_array().to_vec();
102        let y_vec: Vec<f64> = y.as_array().to_vec();
103        let x_arr = Array1_17::from_vec(x_vec);
104        let y_arr = Array1_17::from_vec(y_vec);
105
106        let spline = match bc_type.to_lowercase().as_str() {
107            "natural" | "not-a-knot" | "periodic" => CubicSpline::new(&x_arr.view(), &y_arr.view())
108                .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?,
109            _ => {
110                return Err(pyo3::exceptions::PyValueError::new_err(format!(
111                    "Unsupported boundary condition: {}",
112                    bc_type
113                )));
114            }
115        };
116
117        Ok(PyCubicSpline { spline })
118    }
119
120    /// Evaluate the spline at new points
121    fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
122        let x_vec: Vec<f64> = x_new.as_array().to_vec();
123        let mut result = Vec::with_capacity(x_vec.len());
124
125        for &x in &x_vec {
126            let y = self
127                .spline
128                .evaluate(x)
129                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
130            result.push(y);
131        }
132
133        scirs_to_numpy_array1(Array1::from_vec(result), py)
134    }
135
136    /// Evaluate at a single point
137    fn eval_single(&self, x: f64) -> PyResult<f64> {
138        self.spline
139            .evaluate(x)
140            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
141    }
142
143    /// Compute derivative at new points
144    ///
145    /// Parameters:
146    /// - x_new: points to evaluate derivative
147    /// - nu: derivative order (default: 1)
148    #[pyo3(signature = (x_new, nu=1))]
149    fn derivative(
150        &self,
151        py: Python,
152        x_new: PyReadonlyArray1<f64>,
153        nu: usize,
154    ) -> PyResult<Py<PyArray1<f64>>> {
155        let x_vec: Vec<f64> = x_new.as_array().to_vec();
156        let mut result = Vec::with_capacity(x_vec.len());
157
158        for &x in &x_vec {
159            let y = self
160                .spline
161                .derivative_n(x, nu)
162                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
163            result.push(y);
164        }
165
166        scirs_to_numpy_array1(Array1::from_vec(result), py)
167    }
168
169    /// Integrate the spline over an interval
170    ///
171    /// Parameters:
172    /// - a: lower bound
173    /// - b: upper bound
174    fn integrate(&self, a: f64, b: f64) -> PyResult<f64> {
175        self.spline
176            .integrate(a, b)
177            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
178    }
179}
180
181/// 2D interpolation class
182#[pyclass(name = "Interp2d")]
183pub struct PyInterp2d {
184    interp: Interp2d<f64>,
185}
186
187#[pymethods]
188impl PyInterp2d {
189    /// Create a new 2D interpolator
190    ///
191    /// Parameters:
192    /// - x: x coordinates (must be sorted)
193    /// - y: y coordinates (must be sorted)
194    /// - z: z values with shape (len(y), len(x))
195    /// - kind: interpolation method ('linear', 'cubic', 'quintic')
196    #[new]
197    #[pyo3(signature = (x, y, z, kind="linear"))]
198    fn new(
199        x: PyReadonlyArray1<f64>,
200        y: PyReadonlyArray1<f64>,
201        z: PyReadonlyArray2<f64>,
202        kind: &str,
203    ) -> PyResult<Self> {
204        let x_vec: Vec<f64> = x.as_array().to_vec();
205        let y_vec: Vec<f64> = y.as_array().to_vec();
206        let z_arr = z.as_array();
207
208        let x_arr = Array1_17::from_vec(x_vec);
209        let y_arr = Array1_17::from_vec(y_vec);
210
211        // Convert to Array2_17
212        let z_shape = z_arr.shape();
213        let z_vec: Vec<f64> = z_arr.iter().copied().collect();
214        let z_arr_17 = Array2_17::from_shape_vec((z_shape[0], z_shape[1]), z_vec).map_err(|e| {
215            pyo3::exceptions::PyValueError::new_err(format!("Invalid z array: {e}"))
216        })?;
217
218        let interp_kind = match kind.to_lowercase().as_str() {
219            "linear" => Interp2dKind::Linear,
220            "cubic" => Interp2dKind::Cubic,
221            "quintic" => Interp2dKind::Quintic,
222            _ => Interp2dKind::Linear,
223        };
224
225        let interp = Interp2d::new(&x_arr.view(), &y_arr.view(), &z_arr_17.view(), interp_kind)
226            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
227
228        Ok(PyInterp2d { interp })
229    }
230
231    /// Evaluate at a single point (x, y)
232    fn __call__(&self, x: f64, y: f64) -> PyResult<f64> {
233        self.interp
234            .evaluate(x, y)
235            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
236    }
237
238    /// Evaluate at multiple points
239    ///
240    /// Parameters:
241    /// - x_new: x coordinates
242    /// - y_new: y coordinates (must have same length as x_new)
243    fn eval_array(
244        &self,
245        py: Python,
246        x_new: PyReadonlyArray1<f64>,
247        y_new: PyReadonlyArray1<f64>,
248    ) -> PyResult<Py<PyArray1<f64>>> {
249        let x_vec: Vec<f64> = x_new.as_array().to_vec();
250        let y_vec: Vec<f64> = y_new.as_array().to_vec();
251        let x_arr = Array1_17::from_vec(x_vec);
252        let y_arr = Array1_17::from_vec(y_vec);
253
254        let result = self
255            .interp
256            .evaluate_array(&x_arr.view(), &y_arr.view())
257            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
258
259        scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
260    }
261
262    /// Evaluate on a regular grid
263    ///
264    /// Parameters:
265    /// - x_new: x coordinates for output grid
266    /// - y_new: y coordinates for output grid
267    ///
268    /// Returns:
269    /// - 2D array with shape (len(y_new), len(x_new))
270    fn eval_grid(
271        &self,
272        py: Python,
273        x_new: PyReadonlyArray1<f64>,
274        y_new: PyReadonlyArray1<f64>,
275    ) -> PyResult<Py<PyArray2<f64>>> {
276        let x_vec: Vec<f64> = x_new.as_array().to_vec();
277        let y_vec: Vec<f64> = y_new.as_array().to_vec();
278        let x_arr = Array1_17::from_vec(x_vec);
279        let y_arr = Array1_17::from_vec(y_vec);
280
281        let result = self
282            .interp
283            .evaluate_grid(&x_arr.view(), &y_arr.view())
284            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
285
286        // Convert Array2_17 to Array2 (ndarray 0.16)
287        let shape = result.dim();
288        let vec: Vec<f64> = result.into_iter().collect();
289        let arr2 = Array2::from_shape_vec(shape, vec).map_err(|e| {
290            pyo3::exceptions::PyRuntimeError::new_err(format!("Array reshape failed: {e}"))
291        })?;
292        scirs_to_numpy_array2(arr2, py)
293    }
294}
295
296// =============================================================================
297// RBF Interpolation
298// =============================================================================
299
300/// Radial Basis Function (RBF) interpolation class
301#[pyclass(name = "RBFInterpolator")]
302pub struct PyRBFInterpolator {
303    interp: RBFInterpolator<f64>,
304}
305
306#[pymethods]
307impl PyRBFInterpolator {
308    /// Create a new RBF interpolator
309    ///
310    /// Parameters:
311    /// - points: Training data points, shape (n_samples, n_features)
312    /// - values: Training data values, shape (n_samples,)
313    /// - kernel: Kernel type - 'gaussian', 'multiquadric', 'inverse_multiquadric',
314    ///           'thin_plate_spline', 'linear', 'cubic' (default: 'gaussian')
315    /// - epsilon: Shape parameter for the kernel (default: 1.0)
316    #[new]
317    #[pyo3(signature = (points, values, kernel="gaussian", epsilon=1.0))]
318    fn new(
319        points: PyReadonlyArray2<f64>,
320        values: PyReadonlyArray1<f64>,
321        kernel: &str,
322        epsilon: f64,
323    ) -> PyResult<Self> {
324        let pts = points.as_array();
325        let vals = values.as_array();
326
327        // Convert to ndarray 0.17 types
328        let shape = pts.dim();
329        let pts_vec: Vec<f64> = pts.iter().copied().collect();
330        let pts_17 = Array2_17::from_shape_vec(shape, pts_vec).map_err(|e| {
331            pyo3::exceptions::PyValueError::new_err(format!("Invalid points array: {e}"))
332        })?;
333        let vals_vec: Vec<f64> = vals.iter().copied().collect();
334        let vals_17 = scirs2_core::ndarray::Array1::from_vec(vals_vec);
335
336        let rbf_kernel = match kernel.to_lowercase().as_str() {
337            "gaussian" => RBFKernel::Gaussian,
338            "multiquadric" => RBFKernel::Multiquadric,
339            "inverse_multiquadric" | "inverse-multiquadric" => RBFKernel::InverseMultiquadric,
340            "thin_plate_spline" | "thin-plate-spline" => RBFKernel::ThinPlateSpline,
341            "linear" => RBFKernel::Linear,
342            "cubic" => RBFKernel::Cubic,
343            _ => RBFKernel::Gaussian,
344        };
345
346        let interp = RBFInterpolator::new(&pts_17.view(), &vals_17.view(), rbf_kernel, epsilon)
347            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
348
349        Ok(PyRBFInterpolator { interp })
350    }
351
352    /// Evaluate the RBF interpolator at new points
353    ///
354    /// Parameters:
355    /// - query_points: Points to evaluate, shape (n_query, n_features)
356    ///
357    /// Returns:
358    /// - 1D array of interpolated values, shape (n_query,)
359    fn __call__(
360        &self,
361        py: Python,
362        query_points: PyReadonlyArray2<f64>,
363    ) -> PyResult<Py<PyArray1<f64>>> {
364        let pts = query_points.as_array();
365        let shape = pts.dim();
366        let pts_vec: Vec<f64> = pts.iter().copied().collect();
367        let pts_17 = Array2_17::from_shape_vec(shape, pts_vec).map_err(|e| {
368            pyo3::exceptions::PyValueError::new_err(format!("Invalid query array: {e}"))
369        })?;
370
371        let result = self
372            .interp
373            .interpolate(&pts_17.view())
374            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
375
376        scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
377    }
378}
379
380// =============================================================================
381// Simple interpolation functions
382// =============================================================================
383
384/// Linear interpolation - optimized direct implementation
385#[pyfunction]
386fn interp_py(
387    py: Python,
388    x: PyReadonlyArray1<f64>,
389    xp: PyReadonlyArray1<f64>,
390    fp: PyReadonlyArray1<f64>,
391) -> PyResult<Py<PyArray1<f64>>> {
392    let x_arr = x.as_array();
393    let xp_arr = xp.as_array();
394    let fp_arr = fp.as_array();
395
396    let n = xp_arr.len();
397    if n == 0 {
398        return Err(pyo3::exceptions::PyValueError::new_err(
399            "xp must not be empty",
400        ));
401    }
402    if n != fp_arr.len() {
403        return Err(pyo3::exceptions::PyValueError::new_err(
404            "xp and fp must have same length",
405        ));
406    }
407
408    let xp_slice = xp_arr
409        .as_slice()
410        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("xp array is not contiguous"))?;
411    let fp_slice = fp_arr
412        .as_slice()
413        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("fp array is not contiguous"))?;
414
415    // Pre-allocate result
416    let mut result = Vec::with_capacity(x_arr.len());
417
418    for &xi in x_arr.iter() {
419        let yi = if xi <= xp_slice[0] {
420            fp_slice[0]
421        } else if xi >= xp_slice[n - 1] {
422            fp_slice[n - 1]
423        } else {
424            // Binary search for interval
425            let idx = xp_slice.partition_point(|&v| v < xi);
426            let i = if idx > 0 { idx - 1 } else { 0 };
427
428            // Linear interpolation
429            let x0 = xp_slice[i];
430            let x1 = xp_slice[i + 1];
431            let y0 = fp_slice[i];
432            let y1 = fp_slice[i + 1];
433            let t = (xi - x0) / (x1 - x0);
434            y0 + t * (y1 - y0)
435        };
436        result.push(yi);
437    }
438
439    scirs_to_numpy_array1(Array1::from_vec(result), py)
440}
441
442/// Piecewise linear interpolation with boundary handling
443#[pyfunction]
444#[pyo3(signature = (x, xp, fp, left=None, right=None))]
445fn interp_with_bounds_py(
446    py: Python,
447    x: PyReadonlyArray1<f64>,
448    xp: PyReadonlyArray1<f64>,
449    fp: PyReadonlyArray1<f64>,
450    left: Option<f64>,
451    right: Option<f64>,
452) -> PyResult<Py<PyArray1<f64>>> {
453    let x_arr = x.as_array();
454    let xp_arr = xp.as_array();
455    let fp_arr = fp.as_array();
456
457    let n = xp_arr.len();
458    if n == 0 || fp_arr.len() != n {
459        return Err(pyo3::exceptions::PyValueError::new_err(
460            "Invalid input arrays",
461        ));
462    }
463
464    let mut result = Vec::with_capacity(x_arr.len());
465
466    for &xi in x_arr.iter() {
467        let yi = if xi < xp_arr[0] {
468            left.unwrap_or(fp_arr[0])
469        } else if xi > xp_arr[n - 1] {
470            right.unwrap_or(fp_arr[n - 1])
471        } else {
472            // Binary search for interval
473            let mut lo = 0;
474            let mut hi = n - 1;
475            while hi - lo > 1 {
476                let mid = (lo + hi) / 2;
477                if xp_arr[mid] <= xi {
478                    lo = mid;
479                } else {
480                    hi = mid;
481                }
482            }
483            // Linear interpolation
484            let t = (xi - xp_arr[lo]) / (xp_arr[hi] - xp_arr[lo]);
485            fp_arr[lo] * (1.0 - t) + fp_arr[hi] * t
486        };
487        result.push(yi);
488    }
489
490    scirs_to_numpy_array1(Array1::from_vec(result), py)
491}
492
493/// Python module registration
494pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
495    m.add_class::<PyInterp1d>()?;
496    m.add_class::<PyCubicSpline>()?;
497    m.add_class::<PyInterp2d>()?;
498    m.add_class::<PyRBFInterpolator>()?;
499    m.add_function(wrap_pyfunction!(interp_py, m)?)?;
500    m.add_function(wrap_pyfunction!(interp_with_bounds_py, m)?)?;
501
502    Ok(())
503}