Skip to main content

scivex_optim/interpolate/
linear.rs

1//! Piecewise linear interpolation.
2
3use scivex_core::Float;
4
5use crate::error::{OptimError, Result};
6
7use super::{Extrapolate, find_interval, validate_finite, validate_sorted};
8
9/// Piecewise linear 1-D interpolator.
10///
11/// Constructed from sorted `(x, y)` data. Each query is evaluated via binary
12/// search + linear blend in O(log n).
13#[derive(Debug, Clone)]
14pub struct Linear1d<T: Float> {
15    xs: Vec<T>,
16    ys: Vec<T>,
17    extrap: Extrapolate,
18}
19
20impl<T: Float> Linear1d<T> {
21    /// Create a new linear interpolator.
22    ///
23    /// # Examples
24    ///
25    /// ```
26    /// # use scivex_optim::interpolate::{Linear1d, Extrapolate};
27    /// let xs = [0.0_f64, 1.0, 2.0];
28    /// let ys = [0.0_f64, 2.0, 4.0];
29    /// let interp = Linear1d::new(&xs, &ys, Extrapolate::Error).unwrap();
30    /// let y = interp.eval(0.5).unwrap();
31    /// assert!((y - 1.0).abs() < 1e-10);
32    /// ```
33    ///
34    /// # Errors
35    ///
36    /// - `xs` and `ys` must have the same length (>= 2).
37    /// - `xs` must be strictly increasing.
38    /// - All values must be finite.
39    pub fn new(xs: &[T], ys: &[T], extrap: Extrapolate) -> Result<Self> {
40        if xs.len() != ys.len() {
41            return Err(OptimError::InvalidParameter {
42                name: "ys",
43                reason: "xs and ys must have the same length",
44            });
45        }
46        validate_sorted(xs, 2)?;
47        validate_finite(xs, "xs")?;
48        validate_finite(ys, "ys")?;
49
50        Ok(Self {
51            xs: xs.to_vec(),
52            ys: ys.to_vec(),
53            extrap,
54        })
55    }
56
57    /// Evaluate the interpolant at a single point.
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// # use scivex_optim::interpolate::{Linear1d, Extrapolate};
63    /// let interp = Linear1d::new(&[0.0_f64, 1.0, 2.0], &[0.0_f64, 1.0, 4.0], Extrapolate::Error).unwrap();
64    /// let y = interp.eval(1.5).unwrap();
65    /// assert!((y - 2.5).abs() < 1e-10);
66    /// ```
67    pub fn eval(&self, x: T) -> Result<T> {
68        let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
69        let dx = self.xs[i + 1] - self.xs[i];
70        let t = (xq - self.xs[i]) / dx;
71        Ok(self.ys[i] * (T::one() - t) + self.ys[i + 1] * t)
72    }
73
74    /// Evaluate the interpolant at many points.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// # use scivex_optim::interpolate::{Linear1d, Extrapolate};
80    /// let interp = Linear1d::new(&[0.0_f64, 1.0, 2.0], &[0.0, 2.0, 4.0], Extrapolate::Error).unwrap();
81    /// let ys = interp.eval_many(&[0.5, 1.5]).unwrap();
82    /// assert!((ys[0] - 1.0).abs() < 1e-10);
83    /// assert!((ys[1] - 3.0).abs() < 1e-10);
84    /// ```
85    pub fn eval_many(&self, xs: &[T]) -> Result<Vec<T>> {
86        xs.iter().map(|&x| self.eval(x)).collect()
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_exact() {
        // y = 2x on [0,3]
        let xs = [0.0, 1.0, 2.0, 3.0];
        let ys = [0.0, 2.0, 4.0, 6.0];
        let interp = Linear1d::new(&xs, &ys, Extrapolate::Error).unwrap();

        for &x in &[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] {
            let y = interp.eval(x).unwrap();
            assert!((y - 2.0 * x).abs() < 1e-12, "x={x}, y={y}");
        }
    }

    #[test]
    fn test_linear_midpoint() {
        let interp = Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 4.0], Extrapolate::Error).unwrap();
        let y = interp.eval(1.5).unwrap();
        assert!((y - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_linear_eval_many() {
        let interp = Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 4.0], Extrapolate::Error).unwrap();
        let result = interp.eval_many(&[0.0, 1.0, 2.0]).unwrap();
        assert!(result[0].abs() < 1e-12);
        assert!((result[1] - 1.0).abs() < 1e-12);
        assert!((result[2] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_linear_eval_many_propagates_error() {
        let interp = Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 4.0], Extrapolate::Error).unwrap();
        // One out-of-range point poisons the whole batch.
        assert!(interp.eval_many(&[0.5, 3.0]).is_err());
        // An empty batch is valid and yields an empty result.
        assert!(interp.eval_many(&[]).unwrap().is_empty());
    }

    #[test]
    fn test_linear_out_of_range_error() {
        let interp = Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 4.0], Extrapolate::Error).unwrap();
        assert!(interp.eval(-0.5).is_err());
        assert!(interp.eval(2.5).is_err());
    }

    #[test]
    fn test_linear_clamp() {
        let interp = Linear1d::new(&[0.0, 1.0, 2.0], &[1.0, 3.0, 7.0], Extrapolate::Clamp).unwrap();
        assert!((interp.eval(-1.0).unwrap() - 1.0).abs() < 1e-12);
        assert!((interp.eval(5.0).unwrap() - 7.0).abs() < 1e-12);
    }

    #[test]
    fn test_linear_two_points() {
        // Smallest valid input: a single segment.
        let interp = Linear1d::new(&[1.0, 3.0], &[10.0, 20.0], Extrapolate::Error).unwrap();
        assert!((interp.eval(1.0).unwrap() - 10.0).abs() < 1e-12);
        assert!((interp.eval(2.0).unwrap() - 15.0).abs() < 1e-12);
        assert!((interp.eval(3.0).unwrap() - 20.0).abs() < 1e-12);
    }

    #[test]
    fn test_linear_unsorted_error() {
        assert!(Linear1d::new(&[0.0, 2.0, 1.0], &[0.0, 1.0, 2.0], Extrapolate::Error).is_err());
    }

    #[test]
    fn test_linear_duplicate_x_error() {
        // `xs` must be strictly increasing, so repeated abscissae are rejected.
        assert!(Linear1d::new(&[0.0, 1.0, 1.0], &[0.0, 1.0, 2.0], Extrapolate::Error).is_err());
    }

    #[test]
    fn test_linear_length_mismatch_error() {
        assert!(Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0], Extrapolate::Error).is_err());
    }

    #[test]
    fn test_linear_too_few_points_error() {
        assert!(Linear1d::new(&[0.0], &[1.0], Extrapolate::Error).is_err());
        assert!(Linear1d::<f64>::new(&[], &[], Extrapolate::Error).is_err());
    }

    #[test]
    fn test_linear_non_finite_error() {
        assert!(Linear1d::new(&[0.0, f64::NAN, 2.0], &[0.0, 1.0, 2.0], Extrapolate::Error).is_err());
        assert!(
            Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, f64::INFINITY, 2.0], Extrapolate::Error).is_err()
        );
    }

    #[test]
    fn test_linear_f32() {
        let interp = Linear1d::new(
            &[0.0_f32, 1.0, 2.0],
            &[0.0_f32, 1.0, 4.0],
            Extrapolate::Error,
        )
        .unwrap();
        let y = interp.eval(1.5_f32).unwrap();
        assert!((y - 2.5_f32).abs() < 1e-6);
    }
}