Skip to main content

scivex_optim/interpolate/
mod.rs

//! Interpolation algorithms (1-D and 2-D).
//!
//! Provides piecewise linear, cubic spline, B-spline, bilinear, and bicubic
//! interpolation. Each interpolator precomputes its coefficients once at
//! construction and can then be evaluated many times.
//!
//! # Quick start
//!
//! ```ignore
//! use scivex_optim::interpolate::{Linear1d, Extrapolate};
//!
//! let xs = [0.0, 1.0, 2.0, 3.0];
//! let ys = [0.0, 1.0, 4.0, 9.0];
//! let interp = Linear1d::new(&xs, &ys, Extrapolate::Error).unwrap();
//! let y = interp.eval(1.5).unwrap(); // 2.5
//! ```
17
18mod bicubic;
19mod bilinear;
20mod bspline;
21mod cubic_spline;
22mod linear;
23pub(super) mod thomas;
24
25pub use bicubic::Bicubic2d;
26pub use bilinear::Bilinear2d;
27pub use bspline::BSpline;
28pub use cubic_spline::CubicSpline;
29pub use linear::Linear1d;
30
31use scivex_core::Float;
32
33use crate::error::{OptimError, Result};
34
35// ---------------------------------------------------------------------------
36// Public types
37// ---------------------------------------------------------------------------
38
/// Method selector for the 1-D convenience function [`interp1d`].
///
/// Besides `Eq`, the selector implements `Hash`, so it can key a `HashMap` or
/// live in a `HashSet` (e.g. when caching one interpolator per method).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interp1dMethod {
    /// Piecewise linear interpolation.
    Linear,
    /// Natural cubic spline interpolation.
    CubicSpline,
    /// Uniform B-spline interpolation (degree 3).
    BSpline,
}
49
/// Method selector for the 2-D convenience function [`interp2d`].
///
/// Besides `Eq`, the selector implements `Hash`, so it can key a `HashMap` or
/// live in a `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Interp2dMethod {
    /// Bilinear interpolation on a rectilinear grid.
    Bilinear,
    /// Bicubic interpolation on a rectilinear grid.
    Bicubic,
}
58
/// End condition used when building a cubic spline.
///
/// `T` is the scalar type of the prescribed derivatives in the clamped case.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SplineBoundary<T> {
    /// Natural spline: the second derivative is zero at both ends.
    Natural,
    /// Clamped spline: the first derivative is fixed at both ends.
    Clamped {
        /// Prescribed first derivative at the left end of the data range.
        left: T,
        /// Prescribed first derivative at the right end of the data range.
        right: T,
    },
}
67
/// Controls behaviour when an evaluation point lies outside the data range.
///
/// The default is [`Extrapolate::Error`]. Besides `Eq`, the policy implements
/// `Hash`, so it can be used as (part of) a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Extrapolate {
    /// Return an error (default).
    #[default]
    Error,
    /// Clamp the query to the nearest boundary value.
    Clamp,
    /// Extend the nearest segment/polynomial beyond the boundary.
    Extend,
}
79
80// ---------------------------------------------------------------------------
81// Shared helpers
82// ---------------------------------------------------------------------------
83
84/// Binary search for the interval index `i` such that `xs[i] <= x < xs[i+1]`.
85///
86/// Returns `Ok(i)` on success.
87/// If `x` is outside `[xs[0], xs[n-1]]`, behaviour depends on `extrap`:
88///
89/// - `Extrapolate::Error` — returns `Err(InvalidParameter)`.
90/// - `Extrapolate::Clamp` — clamps `x` to the boundary and returns the
91///   boundary interval.
92/// - `Extrapolate::Extend` — returns the first or last interval index.
93#[inline]
94pub(crate) fn find_interval<T: Float>(xs: &[T], x: T, extrap: Extrapolate) -> Result<(usize, T)> {
95    debug_assert!(xs.len() >= 2);
96    let n = xs.len();
97
98    if x < xs[0] {
99        return match extrap {
100            Extrapolate::Error => Err(OptimError::InvalidParameter {
101                name: "x",
102                reason: "query point is below data range",
103            }),
104            Extrapolate::Clamp => Ok((0, xs[0])),
105            Extrapolate::Extend => Ok((0, x)),
106        };
107    }
108
109    if x > xs[n - 1] {
110        return match extrap {
111            Extrapolate::Error => Err(OptimError::InvalidParameter {
112                name: "x",
113                reason: "query point is above data range",
114            }),
115            Extrapolate::Clamp => Ok((n - 2, xs[n - 1])),
116            Extrapolate::Extend => Ok((n - 2, x)),
117        };
118    }
119
120    // Exact match on the last knot belongs to the last interval.
121    if x == xs[n - 1] {
122        return Ok((n - 2, x));
123    }
124
125    // Standard binary search
126    let mut lo: usize = 0;
127    let mut hi: usize = n - 1;
128    while hi - lo > 1 {
129        let mid = lo + (hi - lo) / 2;
130        if xs[mid] <= x {
131            lo = mid;
132        } else {
133            hi = mid;
134        }
135    }
136
137    Ok((lo, x))
138}
139
140/// Validate that `xs` is strictly increasing and has at least `min_len` points.
141pub(crate) fn validate_sorted<T: Float>(xs: &[T], min_len: usize) -> Result<()> {
142    if xs.len() < min_len {
143        return Err(OptimError::InvalidParameter {
144            name: "xs",
145            reason: "not enough data points",
146        });
147    }
148    for i in 1..xs.len() {
149        if xs[i] <= xs[i - 1] {
150            return Err(OptimError::InvalidParameter {
151                name: "xs",
152                reason: "knots must be strictly increasing",
153            });
154        }
155    }
156    Ok(())
157}
158
159/// Validate that no values are NaN or infinite.
160pub(crate) fn validate_finite<T: Float>(vals: &[T], name: &'static str) -> Result<()> {
161    for &v in vals {
162        if !v.is_finite() {
163            return Err(OptimError::NonFiniteValue { context: name });
164        }
165    }
166    Ok(())
167}
168
169// ---------------------------------------------------------------------------
170// Convenience functions
171// ---------------------------------------------------------------------------
172
173/// One-shot 1-D interpolation: build an interpolator and evaluate at `query`.
174///
175/// For repeated evaluation, prefer constructing the interpolator directly.
176pub fn interp1d<T: Float>(
177    xs: &[T],
178    ys: &[T],
179    query: &[T],
180    method: Interp1dMethod,
181) -> Result<Vec<T>> {
182    match method {
183        Interp1dMethod::Linear => {
184            let interp = Linear1d::new(xs, ys, Extrapolate::Error)?;
185            interp.eval_many(query)
186        }
187        Interp1dMethod::CubicSpline => {
188            let interp = CubicSpline::new(xs, ys, SplineBoundary::Natural, Extrapolate::Error)?;
189            interp.eval_many(query)
190        }
191        Interp1dMethod::BSpline => {
192            let interp = BSpline::fit(xs, ys, 3, Extrapolate::Error)?;
193            interp.eval_many(query)
194        }
195    }
196}
197
198/// One-shot 2-D interpolation: build an interpolator and evaluate at `query`.
199///
200/// `query` is a slice of `(x, y)` pairs.
201pub fn interp2d<T: Float>(
202    xs: Vec<T>,
203    ys: Vec<T>,
204    zs: Vec<Vec<T>>,
205    query: &[(T, T)],
206    method: Interp2dMethod,
207) -> Result<Vec<T>> {
208    match method {
209        Interp2dMethod::Bilinear => {
210            let interp = Bilinear2d::new(xs, ys, zs, Extrapolate::Error)?;
211            interp.eval_many(query)
212        }
213        Interp2dMethod::Bicubic => {
214            let interp = Bicubic2d::new(xs, ys, &zs, Extrapolate::Error)?;
215            interp.eval_many(query)
216        }
217    }
218}
219
#[cfg(test)]
// Several tests compare exactly-representable float results with `assert_eq!`.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    // -- find_interval ------------------------------------------------------

    #[test]
    fn test_find_interval_basic() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 1.5, Extrapolate::Error).unwrap();
        assert_eq!(i, 1);
        assert!((x - 1.5).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_first_point() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 0.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 0);
        assert_eq!(x, 0.0);
    }

    #[test]
    fn test_find_interval_interior_knot_opens_interval() {
        // An interior knot is the left end of the interval to its right.
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 1.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 1);
        assert_eq!(x, 1.0);
    }

    #[test]
    fn test_find_interval_last_point() {
        let xs = [0.0, 1.0, 2.0, 3.0];
        let (i, x) = find_interval(&xs, 3.0, Extrapolate::Error).unwrap();
        assert_eq!(i, 2);
        assert!((x - 3.0).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_error_below() {
        let xs = [0.0, 1.0, 2.0];
        let res = find_interval(&xs, -0.1, Extrapolate::Error);
        assert!(res.is_err());
    }

    #[test]
    fn test_find_interval_error_above() {
        let xs = [0.0, 1.0, 2.0];
        let res = find_interval(&xs, 2.1, Extrapolate::Error);
        assert!(matches!(
            res,
            Err(OptimError::InvalidParameter { name: "x", .. })
        ));
    }

    #[test]
    fn test_find_interval_clamp_below() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, -5.0, Extrapolate::Clamp).unwrap();
        assert_eq!(i, 0);
        assert_eq!(x, 0.0);
    }

    #[test]
    fn test_find_interval_clamp_above() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, 5.0, Extrapolate::Clamp).unwrap();
        assert_eq!(i, 1);
        assert!((x - 2.0).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_extend_below() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, -1.0, Extrapolate::Extend).unwrap();
        assert_eq!(i, 0);
        assert!((x - (-1.0)).abs() < 1e-15);
    }

    #[test]
    fn test_find_interval_extend_above() {
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, 5.0, Extrapolate::Extend).unwrap();
        assert_eq!(i, 1);
        assert_eq!(x, 5.0);
    }

    #[test]
    fn test_find_interval_nan_query_passes_through() {
        // NaN fails every range comparison, so it is neither rejected nor
        // clamped: it lands in interval 0 and is handed back unchanged.
        let xs = [0.0, 1.0, 2.0];
        let (i, x) = find_interval(&xs, f64::NAN, Extrapolate::Error).unwrap();
        assert_eq!(i, 0);
        assert!(x.is_nan());
    }

    // -- validation helpers -------------------------------------------------

    #[test]
    fn test_validate_sorted_ok() {
        assert!(validate_sorted(&[0.0, 1.0, 2.0], 2).is_ok());
    }

    #[test]
    fn test_validate_sorted_exact_min_len() {
        assert!(validate_sorted(&[0.0, 1.0], 2).is_ok());
    }

    #[test]
    fn test_validate_sorted_too_few() {
        assert!(validate_sorted(&[0.0_f64], 2).is_err());
    }

    #[test]
    fn test_validate_sorted_not_increasing() {
        assert!(validate_sorted(&[0.0, 2.0, 1.0], 2).is_err());
    }

    #[test]
    fn test_validate_sorted_duplicate_knots() {
        assert!(validate_sorted(&[0.0, 1.0, 1.0], 2).is_err());
    }

    #[test]
    fn test_validate_finite_ok() {
        assert!(validate_finite(&[0.0, -1.5, 1e300], "ys").is_ok());
    }

    #[test]
    fn test_validate_finite_rejects_nan() {
        assert!(matches!(
            validate_finite(&[0.0, f64::NAN], "ys"),
            Err(OptimError::NonFiniteValue { context: "ys" })
        ));
    }

    #[test]
    fn test_validate_finite_rejects_infinity() {
        assert!(matches!(
            validate_finite(&[f64::INFINITY, 1.0], "xs"),
            Err(OptimError::NonFiniteValue { context: "xs" })
        ));
    }

    #[test]
    fn test_extrapolate_default_is_error() {
        assert_eq!(Extrapolate::default(), Extrapolate::Error);
    }

    // -- convenience functions ----------------------------------------------

    #[test]
    fn test_interp1d_linear() {
        let result = interp1d(
            &[0.0, 1.0, 2.0],
            &[0.0, 2.0, 4.0],
            &[0.5, 1.5],
            Interp1dMethod::Linear,
        )
        .unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_interp1d_cubic_spline() {
        let result = interp1d(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 1.0, 4.0, 9.0],
            &[1.0, 2.0],
            Interp1dMethod::CubicSpline,
        )
        .unwrap();
        assert!((result[0] - 1.0).abs() < 1e-10);
        assert!((result[1] - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_interp1d_bspline() {
        let result = interp1d(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[0.0, 1.0, 4.0, 9.0, 16.0],
            &[2.0],
            Interp1dMethod::BSpline,
        )
        .unwrap();
        assert!((result[0] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_interp2d_bilinear() {
        let xs = vec![0.0, 1.0];
        let ys = vec![0.0, 1.0];
        let zs = vec![vec![0.0, 2.0], vec![1.0, 3.0]]; // z = x + 2y
        let result = interp2d(xs, ys, zs, &[(0.5, 0.5)], Interp2dMethod::Bilinear).unwrap();
        assert!((result[0] - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_interp2d_bicubic() {
        let xs = vec![0.0, 1.0, 2.0, 3.0];
        let ys = vec![0.0, 1.0, 2.0, 3.0];
        let zs: Vec<Vec<f64>> = (0..4)
            .map(|i| (0..4).map(|j| f64::from(i) + 2.0 * f64::from(j)).collect())
            .collect();
        let result = interp2d(xs, ys, zs, &[(1.5, 1.5)], Interp2dMethod::Bicubic).unwrap();
        assert!((result[0] - 4.5).abs() < 1e-10);
    }
}