Skip to main content

scirs2/
interpolate.rs

1//! Python bindings for scirs2-interpolate
2//!
3//! Provides interpolation methods similar to scipy.interpolate
4
5use pyo3::prelude::*;
6use scirs2_core::ndarray::{Array1 as Array1_17, Array2 as Array2_17};
7use scirs2_core::python::numpy_compat::{
8    scirs_to_numpy_array1, scirs_to_numpy_array2, Array1, Array2,
9};
10use scirs2_numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2};
11
12use scirs2_interpolate::interp1d::{ExtrapolateMode, Interp1d, InterpolationMethod};
13use scirs2_interpolate::interp2d::{Interp2d, Interp2dKind};
14use scirs2_interpolate::spline::CubicSpline;
15
16/// 1D interpolation class
17#[pyclass(name = "Interp1d")]
18pub struct PyInterp1d {
19    interp: Interp1d<f64>,
20}
21
22#[pymethods]
23impl PyInterp1d {
24    /// Create a new 1D interpolator
25    ///
26    /// Parameters:
27    /// - x: x coordinates (must be sorted)
28    /// - y: y coordinates
29    /// - method: 'linear', 'nearest', 'cubic', or 'pchip'
30    /// - extrapolate: 'error', 'const', or 'extrapolate'
31    #[new]
32    #[pyo3(signature = (x, y, method="linear", extrapolate="error"))]
33    fn new(
34        x: PyReadonlyArray1<f64>,
35        y: PyReadonlyArray1<f64>,
36        method: &str,
37        extrapolate: &str,
38    ) -> PyResult<Self> {
39        // Convert from numpy arrays to scirs2-core ndarray17 (single copy)
40        let x_arr = x.as_array().to_owned();
41        let y_arr = y.as_array().to_owned();
42
43        let method = match method.to_lowercase().as_str() {
44            "nearest" => InterpolationMethod::Nearest,
45            "linear" => InterpolationMethod::Linear,
46            "cubic" => InterpolationMethod::Cubic,
47            "pchip" => InterpolationMethod::Pchip,
48            _ => InterpolationMethod::Linear,
49        };
50
51        let extrapolate_mode = match extrapolate.to_lowercase().as_str() {
52            "nearest" => ExtrapolateMode::Nearest,
53            "extrapolate" => ExtrapolateMode::Extrapolate,
54            _ => ExtrapolateMode::Error,
55        };
56
57        let interp = Interp1d::new(&x_arr.view(), &y_arr.view(), method, extrapolate_mode)
58            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
59
60        Ok(PyInterp1d { interp })
61    }
62
63    /// Evaluate the interpolator at new points
64    fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
65        let x_vec: Vec<f64> = x_new.as_array().to_vec();
66        let x_arr = Array1_17::from_vec(x_vec);
67        let result = self
68            .interp
69            .evaluate_array(&x_arr.view())
70            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
71        // Convert back to numpy-compatible array
72        scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
73    }
74
75    /// Evaluate at a single point
76    fn eval_single(&self, x: f64) -> PyResult<f64> {
77        self.interp
78            .evaluate(x)
79            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
80    }
81}
82
83/// Cubic spline interpolation class
84#[pyclass(name = "CubicSpline")]
85pub struct PyCubicSpline {
86    spline: CubicSpline<f64>,
87}
88
89#[pymethods]
90impl PyCubicSpline {
91    /// Create a new cubic spline interpolator
92    ///
93    /// Parameters:
94    /// - x: x coordinates (must be sorted, strictly increasing)
95    /// - y: y coordinates
96    /// - bc_type: boundary condition type ('natural', 'not-a-knot', 'clamped', 'periodic')
97    #[new]
98    #[pyo3(signature = (x, y, bc_type="natural"))]
99    fn new(x: PyReadonlyArray1<f64>, y: PyReadonlyArray1<f64>, bc_type: &str) -> PyResult<Self> {
100        let x_vec: Vec<f64> = x.as_array().to_vec();
101        let y_vec: Vec<f64> = y.as_array().to_vec();
102        let x_arr = Array1_17::from_vec(x_vec);
103        let y_arr = Array1_17::from_vec(y_vec);
104
105        let spline = match bc_type.to_lowercase().as_str() {
106            "natural" | "not-a-knot" | "periodic" => CubicSpline::new(&x_arr.view(), &y_arr.view())
107                .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?,
108            _ => {
109                return Err(pyo3::exceptions::PyValueError::new_err(format!(
110                    "Unsupported boundary condition: {}",
111                    bc_type
112                )));
113            }
114        };
115
116        Ok(PyCubicSpline { spline })
117    }
118
119    /// Evaluate the spline at new points
120    fn __call__(&self, py: Python, x_new: PyReadonlyArray1<f64>) -> PyResult<Py<PyArray1<f64>>> {
121        let x_vec: Vec<f64> = x_new.as_array().to_vec();
122        let mut result = Vec::with_capacity(x_vec.len());
123
124        for &x in &x_vec {
125            let y = self
126                .spline
127                .evaluate(x)
128                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
129            result.push(y);
130        }
131
132        scirs_to_numpy_array1(Array1::from_vec(result), py)
133    }
134
135    /// Evaluate at a single point
136    fn eval_single(&self, x: f64) -> PyResult<f64> {
137        self.spline
138            .evaluate(x)
139            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
140    }
141
142    /// Compute derivative at new points
143    ///
144    /// Parameters:
145    /// - x_new: points to evaluate derivative
146    /// - nu: derivative order (default: 1)
147    #[pyo3(signature = (x_new, nu=1))]
148    fn derivative(
149        &self,
150        py: Python,
151        x_new: PyReadonlyArray1<f64>,
152        nu: usize,
153    ) -> PyResult<Py<PyArray1<f64>>> {
154        let x_vec: Vec<f64> = x_new.as_array().to_vec();
155        let mut result = Vec::with_capacity(x_vec.len());
156
157        for &x in &x_vec {
158            let y = self
159                .spline
160                .derivative_n(x, nu)
161                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
162            result.push(y);
163        }
164
165        scirs_to_numpy_array1(Array1::from_vec(result), py)
166    }
167
168    /// Integrate the spline over an interval
169    ///
170    /// Parameters:
171    /// - a: lower bound
172    /// - b: upper bound
173    fn integrate(&self, a: f64, b: f64) -> PyResult<f64> {
174        self.spline
175            .integrate(a, b)
176            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
177    }
178}
179
180/// 2D interpolation class
181#[pyclass(name = "Interp2d")]
182pub struct PyInterp2d {
183    interp: Interp2d<f64>,
184}
185
186#[pymethods]
187impl PyInterp2d {
188    /// Create a new 2D interpolator
189    ///
190    /// Parameters:
191    /// - x: x coordinates (must be sorted)
192    /// - y: y coordinates (must be sorted)
193    /// - z: z values with shape (len(y), len(x))
194    /// - kind: interpolation method ('linear', 'cubic', 'quintic')
195    #[new]
196    #[pyo3(signature = (x, y, z, kind="linear"))]
197    fn new(
198        x: PyReadonlyArray1<f64>,
199        y: PyReadonlyArray1<f64>,
200        z: PyReadonlyArray2<f64>,
201        kind: &str,
202    ) -> PyResult<Self> {
203        let x_vec: Vec<f64> = x.as_array().to_vec();
204        let y_vec: Vec<f64> = y.as_array().to_vec();
205        let z_arr = z.as_array();
206
207        let x_arr = Array1_17::from_vec(x_vec);
208        let y_arr = Array1_17::from_vec(y_vec);
209
210        // Convert to Array2_17
211        let z_shape = z_arr.shape();
212        let z_vec: Vec<f64> = z_arr.iter().copied().collect();
213        let z_arr_17 = Array2_17::from_shape_vec((z_shape[0], z_shape[1]), z_vec).map_err(|e| {
214            pyo3::exceptions::PyValueError::new_err(format!("Invalid z array: {e}"))
215        })?;
216
217        let interp_kind = match kind.to_lowercase().as_str() {
218            "linear" => Interp2dKind::Linear,
219            "cubic" => Interp2dKind::Cubic,
220            "quintic" => Interp2dKind::Quintic,
221            _ => Interp2dKind::Linear,
222        };
223
224        let interp = Interp2d::new(&x_arr.view(), &y_arr.view(), &z_arr_17.view(), interp_kind)
225            .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
226
227        Ok(PyInterp2d { interp })
228    }
229
230    /// Evaluate at a single point (x, y)
231    fn __call__(&self, x: f64, y: f64) -> PyResult<f64> {
232        self.interp
233            .evaluate(x, y)
234            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))
235    }
236
237    /// Evaluate at multiple points
238    ///
239    /// Parameters:
240    /// - x_new: x coordinates
241    /// - y_new: y coordinates (must have same length as x_new)
242    fn eval_array(
243        &self,
244        py: Python,
245        x_new: PyReadonlyArray1<f64>,
246        y_new: PyReadonlyArray1<f64>,
247    ) -> PyResult<Py<PyArray1<f64>>> {
248        let x_vec: Vec<f64> = x_new.as_array().to_vec();
249        let y_vec: Vec<f64> = y_new.as_array().to_vec();
250        let x_arr = Array1_17::from_vec(x_vec);
251        let y_arr = Array1_17::from_vec(y_vec);
252
253        let result = self
254            .interp
255            .evaluate_array(&x_arr.view(), &y_arr.view())
256            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
257
258        scirs_to_numpy_array1(Array1::from_vec(result.to_vec()), py)
259    }
260
261    /// Evaluate on a regular grid
262    ///
263    /// Parameters:
264    /// - x_new: x coordinates for output grid
265    /// - y_new: y coordinates for output grid
266    ///
267    /// Returns:
268    /// - 2D array with shape (len(y_new), len(x_new))
269    fn eval_grid(
270        &self,
271        py: Python,
272        x_new: PyReadonlyArray1<f64>,
273        y_new: PyReadonlyArray1<f64>,
274    ) -> PyResult<Py<PyArray2<f64>>> {
275        let x_vec: Vec<f64> = x_new.as_array().to_vec();
276        let y_vec: Vec<f64> = y_new.as_array().to_vec();
277        let x_arr = Array1_17::from_vec(x_vec);
278        let y_arr = Array1_17::from_vec(y_vec);
279
280        let result = self
281            .interp
282            .evaluate_grid(&x_arr.view(), &y_arr.view())
283            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
284
285        // Convert Array2_17 to Array2 (ndarray 0.16)
286        let shape = result.dim();
287        let vec: Vec<f64> = result.into_iter().collect();
288        scirs_to_numpy_array2(
289            Array2::from_shape_vec(shape, vec).expect("Operation failed"),
290            py,
291        )
292    }
293}
294
295// =============================================================================
296// Simple interpolation functions
297// =============================================================================
298
299/// Linear interpolation - optimized direct implementation
300#[pyfunction]
301fn interp_py(
302    py: Python,
303    x: PyReadonlyArray1<f64>,
304    xp: PyReadonlyArray1<f64>,
305    fp: PyReadonlyArray1<f64>,
306) -> PyResult<Py<PyArray1<f64>>> {
307    let x_arr = x.as_array();
308    let xp_arr = xp.as_array();
309    let fp_arr = fp.as_array();
310
311    let n = xp_arr.len();
312    if n == 0 {
313        return Err(pyo3::exceptions::PyValueError::new_err(
314            "xp must not be empty",
315        ));
316    }
317    if n != fp_arr.len() {
318        return Err(pyo3::exceptions::PyValueError::new_err(
319            "xp and fp must have same length",
320        ));
321    }
322
323    let xp_slice = xp_arr.as_slice().expect("Operation failed");
324    let fp_slice = fp_arr.as_slice().expect("Operation failed");
325
326    // Pre-allocate result
327    let mut result = Vec::with_capacity(x_arr.len());
328
329    for &xi in x_arr.iter() {
330        let yi = if xi <= xp_slice[0] {
331            fp_slice[0]
332        } else if xi >= xp_slice[n - 1] {
333            fp_slice[n - 1]
334        } else {
335            // Binary search for interval
336            let idx = xp_slice.partition_point(|&v| v < xi);
337            let i = if idx > 0 { idx - 1 } else { 0 };
338
339            // Linear interpolation
340            let x0 = xp_slice[i];
341            let x1 = xp_slice[i + 1];
342            let y0 = fp_slice[i];
343            let y1 = fp_slice[i + 1];
344            let t = (xi - x0) / (x1 - x0);
345            y0 + t * (y1 - y0)
346        };
347        result.push(yi);
348    }
349
350    scirs_to_numpy_array1(Array1::from_vec(result), py)
351}
352
353/// Piecewise linear interpolation with boundary handling
354#[pyfunction]
355#[pyo3(signature = (x, xp, fp, left=None, right=None))]
356fn interp_with_bounds_py(
357    py: Python,
358    x: PyReadonlyArray1<f64>,
359    xp: PyReadonlyArray1<f64>,
360    fp: PyReadonlyArray1<f64>,
361    left: Option<f64>,
362    right: Option<f64>,
363) -> PyResult<Py<PyArray1<f64>>> {
364    let x_arr = x.as_array();
365    let xp_arr = xp.as_array();
366    let fp_arr = fp.as_array();
367
368    let n = xp_arr.len();
369    if n == 0 || fp_arr.len() != n {
370        return Err(pyo3::exceptions::PyValueError::new_err(
371            "Invalid input arrays",
372        ));
373    }
374
375    let mut result = Vec::with_capacity(x_arr.len());
376
377    for &xi in x_arr.iter() {
378        let yi = if xi < xp_arr[0] {
379            left.unwrap_or(fp_arr[0])
380        } else if xi > xp_arr[n - 1] {
381            right.unwrap_or(fp_arr[n - 1])
382        } else {
383            // Binary search for interval
384            let mut lo = 0;
385            let mut hi = n - 1;
386            while hi - lo > 1 {
387                let mid = (lo + hi) / 2;
388                if xp_arr[mid] <= xi {
389                    lo = mid;
390                } else {
391                    hi = mid;
392                }
393            }
394            // Linear interpolation
395            let t = (xi - xp_arr[lo]) / (xp_arr[hi] - xp_arr[lo]);
396            fp_arr[lo] * (1.0 - t) + fp_arr[hi] * t
397        };
398        result.push(yi);
399    }
400
401    scirs_to_numpy_array1(Array1::from_vec(result), py)
402}
403
404/// Python module registration
405pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
406    m.add_class::<PyInterp1d>()?;
407    m.add_class::<PyCubicSpline>()?;
408    m.add_class::<PyInterp2d>()?;
409    m.add_function(wrap_pyfunction!(interp_py, m)?)?;
410    m.add_function(wrap_pyfunction!(interp_with_bounds_py, m)?)?;
411
412    Ok(())
413}