1use scivex_core::Float;
4
5use crate::error::{OptimError, Result};
6
7use super::thomas::thomas_solve;
8use super::{Extrapolate, SplineBoundary, find_interval, validate_finite, validate_sorted};
9
10#[derive(Debug, Clone)]
15pub struct CubicSpline<T: Float> {
16 xs: Vec<T>,
17 coeffs: Vec<(T, T, T, T)>,
19 extrap: Extrapolate,
20}
21
22impl<T: Float> CubicSpline<T> {
23 pub fn new(
42 xs: &[T],
43 ys: &[T],
44 boundary: SplineBoundary<T>,
45 extrap: Extrapolate,
46 ) -> Result<Self> {
47 if xs.len() != ys.len() {
48 return Err(OptimError::InvalidParameter {
49 name: "ys",
50 reason: "xs and ys must have the same length",
51 });
52 }
53 validate_sorted(xs, 3)?;
54 validate_finite(xs, "xs")?;
55 validate_finite(ys, "ys")?;
56
57 let n = xs.len();
58
59 let h: Vec<T> = (0..n - 1).map(|i| xs[i + 1] - xs[i]).collect();
61 let delta: Vec<T> = (0..n - 1).map(|i| (ys[i + 1] - ys[i]) / h[i]).collect();
62
63 let coeffs = match boundary {
64 SplineBoundary::Natural => Self::solve_natural(ys, &h, &delta)?,
65 SplineBoundary::Clamped { left, right } => {
66 Self::solve_clamped(n, ys, &h, &delta, left, right)?
67 }
68 };
69
70 Ok(Self {
71 xs: xs.to_vec(),
72 coeffs,
73 extrap,
74 })
75 }
76
77 fn solve_natural(ys: &[T], h: &[T], delta: &[T]) -> Result<Vec<(T, T, T, T)>> {
78 let n = ys.len();
79 let two = T::from_f64(2.0);
80 let six = T::from_f64(6.0);
81
82 if n == 3 {
83 let three = T::from_f64(3.0);
84 let m1 = three * (delta[1] - delta[0]) / (h[0] + h[1]);
85 let m = vec![T::zero(), m1, T::zero()];
86 return Ok(Self::build_coeffs(ys, h, &m, two, six));
87 }
88
89 let size = n - 2;
90 let mut sub = Vec::with_capacity(size - 1);
91 let mut diag = Vec::with_capacity(size);
92 let mut sup = Vec::with_capacity(size - 1);
93 let mut rhs = Vec::with_capacity(size);
94
95 for i in 0..size {
96 let row = i + 1;
97 diag.push(two * (h[row - 1] + h[row]));
98 rhs.push(six * (delta[row] - delta[row - 1]));
99 if i > 0 {
100 sub.push(h[row - 1]);
101 }
102 if i < size - 1 {
103 sup.push(h[row]);
104 }
105 }
106
107 let m_interior = thomas_solve(&sub, &diag, &sup, &rhs)?;
108
109 let mut m = Vec::with_capacity(n);
110 m.push(T::zero());
111 m.extend_from_slice(&m_interior);
112 m.push(T::zero());
113
114 Ok(Self::build_coeffs(ys, h, &m, two, six))
115 }
116
117 fn solve_clamped(
118 n: usize,
119 ys: &[T],
120 h: &[T],
121 delta: &[T],
122 left_deriv: T,
123 right_deriv: T,
124 ) -> Result<Vec<(T, T, T, T)>> {
125 let two = T::from_f64(2.0);
126 let six = T::from_f64(6.0);
127
128 let mut sub = Vec::with_capacity(n - 1);
129 let mut diag = Vec::with_capacity(n);
130 let mut sup = Vec::with_capacity(n - 1);
131 let mut rhs = Vec::with_capacity(n);
132
133 diag.push(two * h[0]);
135 sup.push(h[0]);
136 rhs.push(six * (delta[0] - left_deriv));
137
138 for i in 1..n - 1 {
140 sub.push(h[i - 1]);
141 diag.push(two * (h[i - 1] + h[i]));
142 sup.push(h[i]);
143 rhs.push(six * (delta[i] - delta[i - 1]));
144 }
145
146 sub.push(h[n - 2]);
148 diag.push(two * h[n - 2]);
149 rhs.push(six * (right_deriv - delta[n - 2]));
150
151 let m = thomas_solve(&sub, &diag, &sup, &rhs)?;
152
153 Ok(Self::build_coeffs(ys, h, &m, two, six))
154 }
155
156 fn build_coeffs(ys: &[T], h: &[T], m: &[T], two: T, six: T) -> Vec<(T, T, T, T)> {
157 let nm1 = h.len();
158 let mut coeffs = Vec::with_capacity(nm1);
159 for i in 0..nm1 {
160 let a = ys[i];
161 let b = (ys[i + 1] - ys[i]) / h[i] - h[i] * (two * m[i] + m[i + 1]) / six;
162 let c = m[i] / two;
163 let d = (m[i + 1] - m[i]) / (six * h[i]);
164 coeffs.push((a, b, c, d));
165 }
166 coeffs
167 }
168
169 pub fn eval(&self, x: T) -> Result<T> {
183 let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
184 let dx = xq - self.xs[i];
185 let (a, b, c, d) = self.coeffs[i];
186 Ok(a + dx * (b + dx * (c + dx * d)))
187 }
188
189 pub fn eval_many(&self, xs: &[T]) -> Result<Vec<T>> {
203 xs.iter().map(|&x| self.eval(x)).collect()
204 }
205
206 pub fn derivative(&self, x: T) -> Result<T> {
220 let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
221 let dx = xq - self.xs[i];
222 let (_, b, c, d) = self.coeffs[i];
223 let two = T::from_f64(2.0);
224 let three = T::from_f64(3.0);
225 Ok(b + dx * (two * c + three * d * dx))
226 }
227
228 pub fn second_derivative(&self, x: T) -> Result<T> {
243 let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
244 let dx = xq - self.xs[i];
245 let (_, _, c, d) = self.coeffs[i];
246 let two = T::from_f64(2.0);
247 let six = T::from_f64(6.0);
248 Ok(two * c + six * d * dx)
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cubic_spline_reproduces_data() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 4.0, 9.0, 16.0];
        let spline =
            CubicSpline::new(&xs, &ys, SplineBoundary::Natural, Extrapolate::Error).unwrap();
        for (i, &x) in xs.iter().enumerate() {
            let y = spline.eval(x).unwrap();
            assert!(
                (y - ys[i]).abs() < 1e-10,
                "at x={x}: got {y}, expected {}",
                ys[i]
            );
        }
    }

    #[test]
    fn test_cubic_spline_natural_boundary() {
        let spline = CubicSpline::new(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[0.0, 0.5, 2.0, 1.5, 0.0],
            SplineBoundary::Natural,
            Extrapolate::Error,
        )
        .unwrap();
        let sd_left = spline.second_derivative(0.0).unwrap();
        let sd_right = spline.second_derivative(4.0).unwrap();
        assert!(sd_left.abs() < 1e-10, "left 2nd deriv = {sd_left}");
        assert!(sd_right.abs() < 1e-10, "right 2nd deriv = {sd_right}");
    }

    #[test]
    fn test_cubic_spline_clamped() {
        let spline = CubicSpline::new(
            &[0.0, 1.0, 2.0, 3.0],
            &[0.0, 1.0, 4.0, 9.0],
            SplineBoundary::Clamped {
                left: 0.0,
                right: 6.0,
            },
            Extrapolate::Error,
        )
        .unwrap();

        let d_left = spline.derivative(0.0).unwrap();
        let d_right = spline.derivative(3.0).unwrap();
        assert!(d_left.abs() < 1e-8, "left derivative = {d_left}");
        assert!((d_right - 6.0).abs() < 1e-8, "right derivative = {d_right}");
    }

    #[test]
    fn test_cubic_spline_exact_for_cubic() {
        let xs = [0.0_f64, 1.0, 2.0, 3.0];
        let ys: Vec<f64> = xs.iter().map(|&x| x * x * x).collect();
        let spline = CubicSpline::new(
            &xs,
            &ys,
            SplineBoundary::Clamped {
                left: 0.0,
                right: 27.0,
            },
            Extrapolate::Error,
        )
        .unwrap();

        for x in [0.5, 1.0, 1.5, 2.0, 2.5] {
            let y = spline.eval(x).unwrap();
            let expected = x * x * x;
            assert!(
                (y - expected).abs() < 1e-8,
                "at x={x}: got {y}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_cubic_spline_derivative() {
        // For y = x² on knots 0..=4 the natural-spline system solves to
        // M = [0, 18/7, 12/7, 18/7, 0], which gives S'(2) = 4 exactly, so a
        // tight tolerance is appropriate.
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys: Vec<f64> = xs.iter().map(|&x| x * x).collect();
        let spline =
            CubicSpline::new(&xs, &ys, SplineBoundary::Natural, Extrapolate::Error).unwrap();
        let d = spline.derivative(2.0).unwrap();
        assert!((d - 4.0).abs() < 1e-10, "derivative at 2.0 = {d}");
    }

    #[test]
    fn test_cubic_spline_continuity() {
        let xs = [0.0, 1.0, 2.0, 3.0, 4.0];
        let ys = [0.0, 1.0, 0.0, 1.0, 0.0];
        let spline =
            CubicSpline::new(&xs, &ys, SplineBoundary::Natural, Extrapolate::Error).unwrap();

        let eps = 1e-8;
        for &x in &xs[1..xs.len() - 1] {
            let left = spline.eval(x - eps).unwrap();
            let right = spline.eval(x).unwrap();
            assert!(
                (left - right).abs() < 1e-5,
                "discontinuity at x={x}: left={left}, right={right}"
            );
        }
    }

    #[test]
    fn test_cubic_spline_natural_three_points() {
        // Exercises the closed-form n == 3 branch of `solve_natural`.
        // h = [1, 2], delta = [2, -0.5], so M[1] = 3(-2.5)/3 = -2.5 and
        // S'(1) = 7/6.
        let xs = [0.0, 1.0, 3.0];
        let ys = [0.0, 2.0, 1.0];
        let spline =
            CubicSpline::new(&xs, &ys, SplineBoundary::Natural, Extrapolate::Error).unwrap();

        for (&x, &y) in xs.iter().zip(&ys) {
            let got = spline.eval(x).unwrap();
            assert!((got - y).abs() < 1e-12, "at x={x}: got {got}, expected {y}");
        }

        let d = spline.derivative(1.0).unwrap();
        assert!((d - 7.0 / 6.0).abs() < 1e-12, "derivative at 1.0 = {d}");

        let sd_left = spline.second_derivative(0.0).unwrap();
        let sd_right = spline.second_derivative(3.0).unwrap();
        assert!(sd_left.abs() < 1e-12, "left 2nd deriv = {sd_left}");
        assert!(sd_right.abs() < 1e-12, "right 2nd deriv = {sd_right}");
    }

    #[test]
    fn test_cubic_spline_linear_data_is_exact() {
        // Collinear data makes every right-hand side zero, so M = 0 and both
        // boundary kinds must reproduce the line itself between knots.
        let xs = [0.0, 1.0, 2.5, 4.0];
        let ys: Vec<f64> = xs.iter().map(|&x| 2.0 * x + 1.0).collect();
        for boundary in [
            SplineBoundary::Natural,
            SplineBoundary::Clamped {
                left: 2.0,
                right: 2.0,
            },
        ] {
            let spline = CubicSpline::new(&xs, &ys, boundary, Extrapolate::Error).unwrap();
            for x in [0.5, 1.75, 3.0, 3.9] {
                let y = spline.eval(x).unwrap();
                let expected = 2.0 * x + 1.0;
                assert!(
                    (y - expected).abs() < 1e-12,
                    "at x={x}: got {y}, expected {expected}"
                );
            }
        }
    }

    #[test]
    fn test_cubic_spline_rejects_length_mismatch() {
        let result = CubicSpline::new(
            &[0.0, 1.0, 2.0],
            &[0.0, 1.0],
            SplineBoundary::Natural,
            Extrapolate::Error,
        );
        assert!(matches!(
            result,
            Err(OptimError::InvalidParameter { name: "ys", .. })
        ));
    }
}