Skip to main content

scivex_optim/interpolate/
cubic_spline.rs

1//! Natural and clamped cubic spline interpolation.
2
3use scivex_core::Float;
4
5use crate::error::{OptimError, Result};
6
7use super::thomas::thomas_solve;
8use super::{Extrapolate, SplineBoundary, find_interval, validate_finite, validate_sorted};
9
/// Cubic spline 1-D interpolator.
///
/// Precomputes polynomial coefficients for each segment at construction time,
/// so an evaluation only has to locate the segment containing the query point
/// and evaluate one cubic in Horner form.
/// Each evaluation is O(log n) via binary search.
#[derive(Debug, Clone)]
pub struct CubicSpline<T: Float> {
    /// Knot abscissae, copied from the constructor input (required to be strictly increasing).
    xs: Vec<T>,
    /// Coefficients `(a, b, c, d)` per segment: `S_i(x) = a + b*(x-x_i) + c*(x-x_i)^2 + d*(x-x_i)^3`
    /// (one tuple per interval `[xs[i], xs[i + 1]]`, i.e. `xs.len() - 1` entries).
    coeffs: Vec<(T, T, T, T)>,
    /// Extrapolation policy forwarded to `find_interval` on every evaluation.
    extrap: Extrapolate,
}
21
22impl<T: Float> CubicSpline<T> {
23    /// Construct a cubic spline interpolator.
24    ///
25    /// # Errors
26    ///
27    /// - `xs` and `ys` must have the same length (>= 3).
28    /// - `xs` must be strictly increasing.
29    /// - All values must be finite.
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// # use scivex_optim::interpolate::{CubicSpline, SplineBoundary, Extrapolate};
35    /// let xs = [0.0_f64, 1.0, 2.0, 3.0];
36    /// let ys = [0.0, 1.0, 4.0, 9.0]; // roughly x²
37    /// let spline = CubicSpline::new(&xs, &ys, SplineBoundary::Natural, Extrapolate::Error).unwrap();
38    /// let y = spline.eval(1.5).unwrap();
39    /// assert!((y - 2.25).abs() < 0.5); // close to 1.5²
40    /// ```
41    pub fn new(
42        xs: &[T],
43        ys: &[T],
44        boundary: SplineBoundary<T>,
45        extrap: Extrapolate,
46    ) -> Result<Self> {
47        if xs.len() != ys.len() {
48            return Err(OptimError::InvalidParameter {
49                name: "ys",
50                reason: "xs and ys must have the same length",
51            });
52        }
53        validate_sorted(xs, 3)?;
54        validate_finite(xs, "xs")?;
55        validate_finite(ys, "ys")?;
56
57        let n = xs.len();
58
59        // Step widths and divided differences
60        let h: Vec<T> = (0..n - 1).map(|i| xs[i + 1] - xs[i]).collect();
61        let delta: Vec<T> = (0..n - 1).map(|i| (ys[i + 1] - ys[i]) / h[i]).collect();
62
63        let coeffs = match boundary {
64            SplineBoundary::Natural => Self::solve_natural(ys, &h, &delta)?,
65            SplineBoundary::Clamped { left, right } => {
66                Self::solve_clamped(n, ys, &h, &delta, left, right)?
67            }
68        };
69
70        Ok(Self {
71            xs: xs.to_vec(),
72            coeffs,
73            extrap,
74        })
75    }
76
77    fn solve_natural(ys: &[T], h: &[T], delta: &[T]) -> Result<Vec<(T, T, T, T)>> {
78        let n = ys.len();
79        let two = T::from_f64(2.0);
80        let six = T::from_f64(6.0);
81
82        if n == 3 {
83            let three = T::from_f64(3.0);
84            let m1 = three * (delta[1] - delta[0]) / (h[0] + h[1]);
85            let m = vec![T::zero(), m1, T::zero()];
86            return Ok(Self::build_coeffs(ys, h, &m, two, six));
87        }
88
89        let size = n - 2;
90        let mut sub = Vec::with_capacity(size - 1);
91        let mut diag = Vec::with_capacity(size);
92        let mut sup = Vec::with_capacity(size - 1);
93        let mut rhs = Vec::with_capacity(size);
94
95        for i in 0..size {
96            let row = i + 1;
97            diag.push(two * (h[row - 1] + h[row]));
98            rhs.push(six * (delta[row] - delta[row - 1]));
99            if i > 0 {
100                sub.push(h[row - 1]);
101            }
102            if i < size - 1 {
103                sup.push(h[row]);
104            }
105        }
106
107        let m_interior = thomas_solve(&sub, &diag, &sup, &rhs)?;
108
109        let mut m = Vec::with_capacity(n);
110        m.push(T::zero());
111        m.extend_from_slice(&m_interior);
112        m.push(T::zero());
113
114        Ok(Self::build_coeffs(ys, h, &m, two, six))
115    }
116
117    fn solve_clamped(
118        n: usize,
119        ys: &[T],
120        h: &[T],
121        delta: &[T],
122        left_deriv: T,
123        right_deriv: T,
124    ) -> Result<Vec<(T, T, T, T)>> {
125        let two = T::from_f64(2.0);
126        let six = T::from_f64(6.0);
127
128        let mut sub = Vec::with_capacity(n - 1);
129        let mut diag = Vec::with_capacity(n);
130        let mut sup = Vec::with_capacity(n - 1);
131        let mut rhs = Vec::with_capacity(n);
132
133        // First row
134        diag.push(two * h[0]);
135        sup.push(h[0]);
136        rhs.push(six * (delta[0] - left_deriv));
137
138        // Interior rows
139        for i in 1..n - 1 {
140            sub.push(h[i - 1]);
141            diag.push(two * (h[i - 1] + h[i]));
142            sup.push(h[i]);
143            rhs.push(six * (delta[i] - delta[i - 1]));
144        }
145
146        // Last row
147        sub.push(h[n - 2]);
148        diag.push(two * h[n - 2]);
149        rhs.push(six * (right_deriv - delta[n - 2]));
150
151        let m = thomas_solve(&sub, &diag, &sup, &rhs)?;
152
153        Ok(Self::build_coeffs(ys, h, &m, two, six))
154    }
155
156    fn build_coeffs(ys: &[T], h: &[T], m: &[T], two: T, six: T) -> Vec<(T, T, T, T)> {
157        let nm1 = h.len();
158        let mut coeffs = Vec::with_capacity(nm1);
159        for i in 0..nm1 {
160            let a = ys[i];
161            let b = (ys[i + 1] - ys[i]) / h[i] - h[i] * (two * m[i] + m[i + 1]) / six;
162            let c = m[i] / two;
163            let d = (m[i + 1] - m[i]) / (six * h[i]);
164            coeffs.push((a, b, c, d));
165        }
166        coeffs
167    }
168
169    /// Evaluate the spline at a single point.
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// # use scivex_optim::interpolate::{CubicSpline, SplineBoundary, Extrapolate};
175    /// let spline = CubicSpline::new(
176    ///     &[0.0_f64, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0, 9.0],
177    ///     SplineBoundary::Natural, Extrapolate::Error,
178    /// ).unwrap();
179    /// let y = spline.eval(1.5).unwrap();
180    /// assert!((y - 2.25).abs() < 0.5);
181    /// ```
182    pub fn eval(&self, x: T) -> Result<T> {
183        let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
184        let dx = xq - self.xs[i];
185        let (a, b, c, d) = self.coeffs[i];
186        Ok(a + dx * (b + dx * (c + dx * d)))
187    }
188
189    /// Evaluate at many points.
190    ///
191    /// # Examples
192    ///
193    /// ```
194    /// # use scivex_optim::interpolate::{CubicSpline, SplineBoundary, Extrapolate};
195    /// let spline = CubicSpline::new(
196    ///     &[0.0_f64, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0, 9.0],
197    ///     SplineBoundary::Natural, Extrapolate::Error,
198    /// ).unwrap();
199    /// let ys = spline.eval_many(&[0.5, 1.5, 2.5]).unwrap();
200    /// assert_eq!(ys.len(), 3);
201    /// ```
202    pub fn eval_many(&self, xs: &[T]) -> Result<Vec<T>> {
203        xs.iter().map(|&x| self.eval(x)).collect()
204    }
205
206    /// Evaluate the first derivative at a single point.
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// # use scivex_optim::interpolate::{CubicSpline, SplineBoundary, Extrapolate};
212    /// let spline = CubicSpline::new(
213    ///     &[0.0_f64, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0, 9.0],
214    ///     SplineBoundary::Natural, Extrapolate::Error,
215    /// ).unwrap();
216    /// let dy = spline.derivative(1.0).unwrap();
217    /// assert!(dy > 0.0); // increasing function
218    /// ```
219    pub fn derivative(&self, x: T) -> Result<T> {
220        let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
221        let dx = xq - self.xs[i];
222        let (_, b, c, d) = self.coeffs[i];
223        let two = T::from_f64(2.0);
224        let three = T::from_f64(3.0);
225        Ok(b + dx * (two * c + three * d * dx))
226    }
227
228    /// Evaluate the second derivative at a single point.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// # use scivex_optim::interpolate::{CubicSpline, SplineBoundary, Extrapolate};
234    /// let spline = CubicSpline::new(
235    ///     &[0.0_f64, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0, 9.0],
236    ///     SplineBoundary::Natural, Extrapolate::Error,
237    /// ).unwrap();
238    /// let d2y = spline.second_derivative(0.0).unwrap();
239    /// // Natural boundary: second derivative = 0 at endpoints
240    /// assert!(d2y.abs() < 1e-10);
241    /// ```
242    pub fn second_derivative(&self, x: T) -> Result<T> {
243        let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
244        let dx = xq - self.xs[i];
245        let (_, _, c, d) = self.coeffs[i];
246        let two = T::from_f64(2.0);
247        let six = T::from_f64(6.0);
248        Ok(two * c + six * d * dx)
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Natural-boundary spline over `f64` data that errors outside the knots.
    fn natural(xs: &[f64], ys: &[f64]) -> CubicSpline<f64> {
        CubicSpline::new(xs, ys, SplineBoundary::Natural, Extrapolate::Error).unwrap()
    }

    /// Clamped-boundary spline over `f64` data with the given end slopes.
    fn clamped(xs: &[f64], ys: &[f64], left: f64, right: f64) -> CubicSpline<f64> {
        CubicSpline::new(
            xs,
            ys,
            SplineBoundary::Clamped { left, right },
            Extrapolate::Error,
        )
        .unwrap()
    }

    #[test]
    fn test_cubic_spline_reproduces_data() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 4.0, 9.0, 16.0];
        let spline = natural(&xs, &ys);
        // The interpolant must hit every knot exactly (up to rounding).
        xs.iter().copied().enumerate().for_each(|(i, x)| {
            let y = spline.eval(x).unwrap();
            assert!(
                (y - ys[i]).abs() < 1e-10,
                "at x={x}: got {y}, expected {}",
                ys[i]
            );
        });
    }

    #[test]
    fn test_cubic_spline_natural_boundary() {
        let spline = natural(&[0.0, 1.0, 2.0, 3.0, 4.0], &[0.0, 0.5, 2.0, 1.5, 0.0]);
        let (sd_left, sd_right) = (
            spline.second_derivative(0.0).unwrap(),
            spline.second_derivative(4.0).unwrap(),
        );
        assert!(sd_left.abs() < 1e-10, "left 2nd deriv = {sd_left}");
        assert!(sd_right.abs() < 1e-10, "right 2nd deriv = {sd_right}");
    }

    #[test]
    fn test_cubic_spline_clamped() {
        let spline = clamped(&[0.0, 1.0, 2.0, 3.0], &[0.0, 1.0, 4.0, 9.0], 0.0, 6.0);
        let (d_left, d_right) = (
            spline.derivative(0.0).unwrap(),
            spline.derivative(3.0).unwrap(),
        );
        assert!(d_left.abs() < 1e-8, "left derivative = {d_left}");
        assert!((d_right - 6.0).abs() < 1e-8, "right derivative = {d_right}");
    }

    #[test]
    fn test_cubic_spline_exact_for_cubic() {
        let xs = [0.0_f64, 1.0, 2.0, 3.0];
        let ys = xs.map(|x| x * x * x);
        // With the true end slopes, a clamped spline reproduces a cubic exactly.
        let spline = clamped(&xs, &ys, 0.0, 27.0);

        let probes = [0.5, 1.0, 1.5, 2.0, 2.5];
        for x in probes {
            let expected = x * x * x;
            let y = spline.eval(x).unwrap();
            assert!(
                (y - expected).abs() < 1e-8,
                "at x={x}: got {y}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_cubic_spline_derivative() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys: [f64; 5] = xs.map(|x| x * x);
        let d = natural(&xs, &ys).derivative(2.0).unwrap();
        assert!((d - 4.0).abs() < 0.5, "derivative at 2.0 = {d}");
    }

    #[test]
    fn test_cubic_spline_continuity() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 0.0, 1.0, 0.0];
        let spline = natural(&xs, &ys);

        let eps = 1e-8;
        // Approach each interior knot from the left and compare with the knot value.
        let interior = &xs[1..xs.len() - 1];
        for x in interior.iter().copied() {
            let (left, right) = (spline.eval(x - eps).unwrap(), spline.eval(x).unwrap());
            assert!(
                (left - right).abs() < 1e-5,
                "discontinuity at x={x}: left={left}, right={right}"
            );
        }
    }
}