scivex_optim/interpolate/
linear.rs1use scivex_core::Float;
4
5use crate::error::{OptimError, Result};
6
7use super::{Extrapolate, find_interval, validate_finite, validate_sorted};
8
9#[derive(Debug, Clone)]
14pub struct Linear1d<T: Float> {
15 xs: Vec<T>,
16 ys: Vec<T>,
17 extrap: Extrapolate,
18}
19
20impl<T: Float> Linear1d<T> {
21 pub fn new(xs: &[T], ys: &[T], extrap: Extrapolate) -> Result<Self> {
40 if xs.len() != ys.len() {
41 return Err(OptimError::InvalidParameter {
42 name: "ys",
43 reason: "xs and ys must have the same length",
44 });
45 }
46 validate_sorted(xs, 2)?;
47 validate_finite(xs, "xs")?;
48 validate_finite(ys, "ys")?;
49
50 Ok(Self {
51 xs: xs.to_vec(),
52 ys: ys.to_vec(),
53 extrap,
54 })
55 }
56
57 pub fn eval(&self, x: T) -> Result<T> {
68 let (i, xq) = find_interval(&self.xs, x, self.extrap)?;
69 let dx = self.xs[i + 1] - self.xs[i];
70 let t = (xq - self.xs[i]) / dx;
71 Ok(self.ys[i] * (T::one() - t) + self.ys[i + 1] * t)
72 }
73
74 pub fn eval_many(&self, xs: &[T]) -> Result<Vec<T>> {
86 xs.iter().map(|&x| self.eval(x)).collect()
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the `f64` fixture shared by several tests: knots at 0, 1, 2 with
    /// values 0, 1, 4.
    fn sample_interp(extrap: Extrapolate) -> Linear1d<f64> {
        Linear1d::new(&[0.0, 1.0, 2.0], &[0.0, 1.0, 4.0], extrap).unwrap()
    }

    /// Data sampled from `y = 2x` must be reproduced at knots and in between.
    #[test]
    fn test_linear_exact() {
        let knots = [0.0, 1.0, 2.0, 3.0];
        let values = [0.0, 2.0, 4.0, 6.0];
        let interp = Linear1d::new(&knots, &values, Extrapolate::Error).unwrap();

        let queries = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        for x in queries.iter().copied() {
            let y = interp.eval(x).unwrap();
            assert!((y - 2.0 * x).abs() < 1e-12, "x={x}, y={y}");
        }
    }

    /// Halfway between (1, 1) and (2, 4) the interpolant is 2.5.
    #[test]
    fn test_linear_midpoint() {
        let interp = sample_interp(Extrapolate::Error);
        let y = interp.eval(1.5).unwrap();
        let deviation = (y - 2.5).abs();
        assert!(deviation < 1e-12);
    }

    /// Batch evaluation at the knots returns the knot values, in order.
    #[test]
    fn test_linear_eval_many() {
        let interp = sample_interp(Extrapolate::Error);
        let result = interp.eval_many(&[0.0, 1.0, 2.0]).unwrap();

        let expected = [0.0, 1.0, 4.0];
        for (idx, &want) in expected.iter().enumerate() {
            assert!((result[idx] - want).abs() < 1e-12);
        }
    }

    /// With `Extrapolate::Error`, queries on either side of the range fail.
    #[test]
    fn test_linear_out_of_range_error() {
        let interp = sample_interp(Extrapolate::Error);
        let below = interp.eval(-0.5);
        let above = interp.eval(2.5);
        assert!(below.is_err());
        assert!(above.is_err());
    }

    /// With `Extrapolate::Clamp`, out-of-range queries take the end values.
    #[test]
    fn test_linear_clamp() {
        let interp =
            Linear1d::new(&[0.0, 1.0, 2.0], &[1.0, 3.0, 7.0], Extrapolate::Clamp).unwrap();

        let below = interp.eval(-1.0).unwrap();
        assert!((below - 1.0).abs() < 1e-12);

        let above = interp.eval(5.0).unwrap();
        assert!((above - 7.0).abs() < 1e-12);
    }

    /// Abscissae that are not in increasing order are rejected at construction.
    #[test]
    fn test_linear_unsorted_error() {
        let built = Linear1d::new(&[0.0, 2.0, 1.0], &[0.0, 1.0, 2.0], Extrapolate::Error);
        assert!(built.is_err());
    }

    /// The interpolant also works in single precision (looser tolerance).
    #[test]
    fn test_linear_f32() {
        let knots = [0.0_f32, 1.0, 2.0];
        let values = [0.0_f32, 1.0, 4.0];
        let interp = Linear1d::new(&knots, &values, Extrapolate::Error).unwrap();

        let y = interp.eval(1.5_f32).unwrap();
        assert!((y - 2.5_f32).abs() < 1e-6);
    }
}