1use crate::error::{InterpolateError, InterpolateResult};
42
/// Strategy used to evaluate an interpolant at points outside its domain
/// `[x_min, x_max]`.
#[derive(Debug, Clone, PartialEq)]
pub enum ExtrapolationMode {
    /// Return the inner interpolant's value at the nearer domain boundary.
    Nearest,
    /// Continue in a straight line from the nearer boundary, with the slope
    /// estimated by a one-sided finite difference just inside the domain.
    Linear,
    /// Extrapolate with the polynomial of the given degree that passes through
    /// `degree + 1` equally spaced samples spanning the domain (a single sample
    /// at the nearer boundary for degree 0).
    Polynomial(usize),
    /// Mirror the query point about the nearer boundary. Only one reflection is
    /// applied; the mirrored point is clamped back into the domain.
    Reflect,
    /// Treat the interpolant as periodic with period `x_max - x_min`.
    Periodic,
    /// Return this constant for every out-of-domain point (NaN is allowed).
    Fill(f64),
    /// Report an out-of-bounds error instead of producing a value.
    Error,
}
72
/// A one-dimensional interpolant: anything that yields a value for a scalar `x`.
///
/// Every `Fn(f64) -> f64` implements this trait, so closures and function items
/// can be wrapped directly.
pub trait Interpolate1D {
    /// Returns the interpolated value at `x`.
    fn interpolate(&self, x: f64) -> f64;
}
84
85impl<F: Fn(f64) -> f64> Interpolate1D for F {
86 fn interpolate(&self, x: f64) -> f64 {
87 (self)(x)
88 }
89}
90
91pub struct ExtrapolatingInterpolator<I: Interpolate1D> {
101 inner: I,
102 x_min: f64,
103 x_max: f64,
104 mode: ExtrapolationMode,
105}
106
107impl<I: Interpolate1D> ExtrapolatingInterpolator<I> {
108 pub fn new(inner: I, x_min: f64, x_max: f64, mode: ExtrapolationMode) -> Self {
116 assert!(
117 x_max > x_min,
118 "x_max ({x_max}) must be strictly greater than x_min ({x_min})"
119 );
120 Self {
121 inner,
122 x_min,
123 x_max,
124 mode,
125 }
126 }
127
128 pub fn eval(&self, x: f64) -> InterpolateResult<f64> {
130 if x >= self.x_min && x <= self.x_max {
131 return Ok(self.inner.interpolate(x));
132 }
133 let period = self.x_max - self.x_min;
134 match &self.mode {
135 ExtrapolationMode::Nearest => {
136 let clamped = x.clamp(self.x_min, self.x_max);
137 Ok(self.inner.interpolate(clamped))
138 }
139
140 ExtrapolationMode::Linear => {
141 let h = period * 1e-5;
142 if x < self.x_min {
143 let f0 = self.inner.interpolate(self.x_min);
145 let f1 = self.inner.interpolate(self.x_min + h);
146 let slope = (f1 - f0) / h;
147 Ok(f0 + slope * (x - self.x_min))
148 } else {
149 let f0 = self.inner.interpolate(self.x_max);
151 let f1 = self.inner.interpolate(self.x_max - h);
152 let slope = (f0 - f1) / h;
153 Ok(f0 + slope * (x - self.x_max))
154 }
155 }
156
157 ExtrapolationMode::Polynomial(deg) => {
158 let deg = *deg;
159 self.poly_extrapolate(x, deg)
160 }
161
162 ExtrapolationMode::Reflect => {
163 let mapped = if x < self.x_min {
164 2.0 * self.x_min - x
165 } else {
166 2.0 * self.x_max - x
167 };
168 let clamped = mapped.clamp(self.x_min, self.x_max);
170 Ok(self.inner.interpolate(clamped))
171 }
172
173 ExtrapolationMode::Periodic => {
174 let wrapped = wrap_periodic(x, self.x_min, self.x_max);
175 Ok(self.inner.interpolate(wrapped))
176 }
177
178 ExtrapolationMode::Fill(v) => Ok(*v),
179
180 ExtrapolationMode::Error => Err(InterpolateError::OutOfBounds(format!(
181 "x={x:.6} is outside domain [{:.6}, {:.6}]",
182 self.x_min, self.x_max
183 ))),
184 }
185 }
186
187 fn poly_extrapolate(&self, x: f64, deg: usize) -> InterpolateResult<f64> {
192 let n = deg + 1;
193 let h = period_step(self.x_min, self.x_max, n);
194 let nodes: Vec<f64> = if x < self.x_min {
196 (0..n).map(|k| self.x_min + k as f64 * h).collect()
197 } else {
198 (0..n)
199 .map(|k| self.x_max - (n - 1 - k) as f64 * h)
200 .collect()
201 };
202 let ys: Vec<f64> = nodes.iter().map(|&xi| self.inner.interpolate(xi)).collect();
203 Ok(lagrange_eval(&nodes, &ys, x))
204 }
205
206 pub fn x_min(&self) -> f64 {
208 self.x_min
209 }
210
211 pub fn x_max(&self) -> f64 {
213 self.x_max
214 }
215
216 pub fn mode(&self) -> &ExtrapolationMode {
218 &self.mode
219 }
220}
221
222pub struct ExtrapolatingInterpolatorAsymmetric<I: Interpolate1D> {
229 inner: I,
230 x_min: f64,
231 x_max: f64,
232 lower_mode: ExtrapolationMode,
233 upper_mode: ExtrapolationMode,
234}
235
236impl<I: Interpolate1D> ExtrapolatingInterpolatorAsymmetric<I> {
237 pub fn new(
239 inner: I,
240 x_min: f64,
241 x_max: f64,
242 lower_mode: ExtrapolationMode,
243 upper_mode: ExtrapolationMode,
244 ) -> Self {
245 assert!(x_max > x_min);
246 Self {
247 inner,
248 x_min,
249 x_max,
250 lower_mode,
251 upper_mode,
252 }
253 }
254
255 pub fn eval(&self, x: f64) -> InterpolateResult<f64> {
257 if x >= self.x_min && x <= self.x_max {
258 return Ok(self.inner.interpolate(x));
259 }
260 let mode = if x < self.x_min {
261 &self.lower_mode
262 } else {
263 &self.upper_mode
264 };
265 let tmp = ExtrapolatingInterpolator {
267 inner: DummyInner(&self.inner),
268 x_min: self.x_min,
269 x_max: self.x_max,
270 mode: mode.clone(),
271 };
272 tmp.eval(x)
273 }
274}
275
276struct DummyInner<'a, I: Interpolate1D>(&'a I);
278
279impl<'a, I: Interpolate1D> Interpolate1D for DummyInner<'a, I> {
280 fn interpolate(&self, x: f64) -> f64 {
281 self.0.interpolate(x)
282 }
283}
284
/// Shifts `x` by a whole number of periods (`x_max - x_min`) so that, for
/// finite `x`, the result lies in `[x_min, x_max]`.
///
/// The final clamp absorbs floating-point round-off that would otherwise land
/// just outside the interval. A NaN `x` yields NaN.
fn wrap_periodic(x: f64, x_min: f64, x_max: f64) -> f64 {
    let width = x_max - x_min;
    let offset = x - x_min;
    let whole_periods = (offset / width).floor();
    let remainder = offset - whole_periods * width;
    (x_min + remainder).clamp(x_min, x_max)
}
296
/// Spacing between `n` equally spaced points that span `[x_min, x_max]` end to
/// end.
///
/// With fewer than two points there is no spacing, so `0.0` is returned.
fn period_step(x_min: f64, x_max: f64, n: usize) -> f64 {
    match n {
        0 | 1 => 0.0,
        _ => (x_max - x_min) / (n - 1) as f64,
    }
}
305
/// Evaluates at `x` the Lagrange-form polynomial through the points
/// `(nodes[i], values[i])`.
///
/// Returns `0.0` for an empty node set. Node pairs whose difference is below
/// `1e-300` in magnitude are left out of the basis product rather than
/// dividing by (almost) zero, so duplicated nodes do not give a true
/// interpolant. `values` must be at least as long as `nodes` (indexing panics
/// otherwise); any extra values are ignored.
fn lagrange_eval(nodes: &[f64], values: &[f64], x: f64) -> f64 {
    nodes.iter().enumerate().fold(0.0_f64, |sum, (i, &xi)| {
        let basis = nodes
            .iter()
            .enumerate()
            .filter(|&(j, _)| j != i)
            .fold(1.0_f64, |acc, (_, &xj)| {
                let denom = xi - xj;
                if denom.abs() < 1e-300 {
                    acc
                } else {
                    acc * ((x - xj) / denom)
                }
            });
        sum + values[i] * basis
    })
}
326
#[cfg(test)]
mod tests {
    // Behavioural tests for every `ExtrapolationMode`, the asymmetric wrapper
    // and the private numeric helpers.
    use super::*;
    use std::f64::consts::PI;

    /// Identity interpolant `f(x) = x`.
    fn linear_unit() -> impl Fn(f64) -> f64 {
        |x| x
    }

    /// `f(x) = sin(x)`, which is periodic over `[0, 2π]`.
    fn sin_interp() -> impl Fn(f64) -> f64 {
        |x: f64| x.sin()
    }

    #[test]
    fn test_extrapolation_nearest_below() {
        let interp =
            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Nearest);
        // Clamped to x_min = 0, where f(0) = 0.
        let val = interp.eval(-0.5).expect("should succeed");
        assert!(val.abs() < 1e-12, "nearest below: {val}");
    }

    #[test]
    fn test_extrapolation_nearest_above() {
        let interp =
            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Nearest);
        // Clamped to x_max = 1, where f(1) = 1.
        let val = interp.eval(2.0).expect("should succeed");
        assert!((val - 1.0).abs() < 1e-12, "nearest above: {val}");
    }

    #[test]
    fn test_extrapolation_linear_below() {
        // Linear extrapolation must reproduce an affine function: f(-1) = 1.
        let inner = |x: f64| 2.0 * x + 3.0;
        let interp = ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Linear);
        let val = interp.eval(-1.0).expect("linear below");
        assert!((val - 1.0).abs() < 1e-4, "linear extrap below: {val}");
    }

    #[test]
    fn test_extrapolation_linear_above() {
        let inner = |x: f64| 2.0 * x + 3.0;
        let interp = ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Linear);
        // f(2) = 7.
        let val = interp.eval(2.0).expect("linear above");
        assert!((val - 7.0).abs() < 1e-3, "linear extrap above: {val}");
    }

    #[test]
    fn test_extrapolation_fill() {
        let fill_val = -999.0;
        let interp = ExtrapolatingInterpolator::new(
            linear_unit(),
            0.0,
            1.0,
            ExtrapolationMode::Fill(fill_val),
        );
        assert_eq!(interp.eval(-5.0).unwrap(), fill_val);
        assert_eq!(interp.eval(5.0).unwrap(), fill_val);
    }

    #[test]
    fn test_extrapolation_fill_nan() {
        let interp = ExtrapolatingInterpolator::new(
            linear_unit(),
            0.0,
            1.0,
            ExtrapolationMode::Fill(f64::NAN),
        );
        assert!(interp.eval(-1.0).unwrap().is_nan());
    }

    #[test]
    fn test_extrapolation_error_mode() {
        let interp =
            ExtrapolatingInterpolator::new(linear_unit(), 0.0, 1.0, ExtrapolationMode::Error);
        assert!(interp.eval(-0.1).is_err(), "Should error below range");
        assert!(interp.eval(1.1).is_err(), "Should error above range");
        // In-domain queries never consult the mode.
        assert!(interp.eval(0.5).is_ok());
    }

    #[test]
    fn test_extrapolation_periodic() {
        let interp = ExtrapolatingInterpolator::new(
            sin_interp(),
            0.0,
            2.0 * PI,
            ExtrapolationMode::Periodic,
        );
        let y1 = interp.eval(PI).unwrap();
        // 3π wraps back to π.
        let y2 = interp.eval(3.0 * PI).unwrap();
        assert!(
            (y1 - y2).abs() < 1e-10,
            "Periodic: sin(π)={y1} should equal sin(3π)={y2}"
        );
    }

    #[test]
    fn test_extrapolation_periodic_negative() {
        let interp = ExtrapolatingInterpolator::new(
            sin_interp(),
            0.0,
            2.0 * PI,
            ExtrapolationMode::Periodic,
        );
        // -π/2 wraps to 3π/2.
        let y1 = interp.eval(-PI / 2.0).unwrap();
        let y2 = interp.eval(3.0 * PI / 2.0).unwrap();
        assert!((y1 - y2).abs() < 1e-10, "Periodic negative: {y1} vs {y2}");
    }

    #[test]
    fn test_extrapolation_reflect_below() {
        // -0.3 mirrors about x_min = 0 to 0.3.
        let interp =
            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Reflect);
        let val = interp.eval(-0.3).unwrap();
        let expected = 0.3_f64 * 0.3;
        assert!(
            (val - expected).abs() < 1e-12,
            "reflect below: {val} vs {expected}"
        );
    }

    #[test]
    fn test_extrapolation_reflect_above() {
        let interp =
            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Reflect);
        // 1.4 mirrors about x_max = 1 to 0.6.
        let val = interp.eval(1.4).unwrap();
        let expected = 0.6_f64 * 0.6;
        assert!(
            (val - expected).abs() < 1e-12,
            "reflect above: {val} vs {expected}"
        );
    }

    #[test]
    fn test_extrapolation_reflect_far_clamps_to_opposite_edge() {
        let interp =
            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Reflect);
        // A single reflection overshoots the domain and is clamped to the far edge.
        assert_eq!(interp.eval(-5.0).unwrap(), 1.0);
        assert_eq!(interp.eval(7.0).unwrap(), 0.0);
    }

    #[test]
    fn test_extrapolation_polynomial_linear_exact() {
        // A degree-1 fit reproduces an affine function exactly.
        let inner = |x: f64| x + 1.0;
        let interp =
            ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Polynomial(1));
        let val = interp.eval(-0.5).unwrap();
        let expected = -0.5 + 1.0;
        assert!(
            (val - expected).abs() < 1e-8,
            "poly extrap degree 1: {val} vs {expected}"
        );
    }

    #[test]
    fn test_extrapolation_polynomial_quadratic() {
        // A degree-2 fit reproduces x² exactly.
        let inner = |x: f64| x * x;
        let interp =
            ExtrapolatingInterpolator::new(inner, 0.0, 1.0, ExtrapolationMode::Polynomial(2));
        let val = interp.eval(2.0).unwrap();
        assert!((val - 4.0).abs() < 1e-6, "poly extrap degree 2 above: {val}");
    }

    #[test]
    fn test_extrapolation_polynomial_degree_zero_is_constant() {
        // Degree 0 samples only the nearer boundary.
        let interp = ExtrapolatingInterpolator::new(
            |x: f64| x * x + 1.0,
            0.0,
            2.0,
            ExtrapolationMode::Polynomial(0),
        );
        assert_eq!(interp.eval(-3.0).unwrap(), 1.0);
        assert_eq!(interp.eval(10.0).unwrap(), 5.0);
    }

    #[test]
    fn test_inside_domain_uses_inner() {
        let interp =
            ExtrapolatingInterpolator::new(|x: f64| x * x, 0.0, 1.0, ExtrapolationMode::Error);
        assert!((interp.eval(0.5).unwrap() - 0.25).abs() < 1e-15);
    }

    #[test]
    #[should_panic(expected = "must be strictly greater")]
    fn test_new_rejects_empty_domain() {
        let _ =
            ExtrapolatingInterpolator::new(linear_unit(), 1.0, 1.0, ExtrapolationMode::Nearest);
    }

    #[test]
    fn test_asymmetric_different_modes() {
        let interp = ExtrapolatingInterpolatorAsymmetric::new(
            linear_unit(),
            0.0,
            1.0,
            ExtrapolationMode::Nearest,
            ExtrapolationMode::Error,
        );
        // Below: Nearest clamps to f(0) = 0.
        let below = interp.eval(-0.5).expect("lower Nearest");
        assert!(below.abs() < 1e-12);
        // Above: Error mode.
        let above = interp.eval(1.5);
        assert!(above.is_err(), "upper Error mode should fail");
    }

    #[test]
    fn test_asymmetric_swapped_modes() {
        let interp = ExtrapolatingInterpolatorAsymmetric::new(
            linear_unit(),
            0.0,
            1.0,
            ExtrapolationMode::Error,
            ExtrapolationMode::Fill(42.0),
        );
        assert!(interp.eval(-0.5).is_err(), "lower Error mode should fail");
        assert_eq!(interp.eval(1.5).unwrap(), 42.0);
        assert_eq!(interp.eval(0.25).unwrap(), 0.25);
    }

    #[test]
    fn test_extrapolation_in_range_all_modes() {
        let modes = [
            ExtrapolationMode::Nearest,
            ExtrapolationMode::Linear,
            ExtrapolationMode::Polynomial(2),
            ExtrapolationMode::Reflect,
            ExtrapolationMode::Periodic,
            ExtrapolationMode::Fill(0.0),
            ExtrapolationMode::Error,
        ];
        for mode in modes {
            let interp = ExtrapolatingInterpolator::new(|x: f64| x, 0.0, 1.0, mode);
            let val = interp
                .eval(0.5)
                .expect("in-range should succeed for any mode");
            assert!((val - 0.5).abs() < 1e-12, "in-range eval failed: {val}");
        }
    }

    #[test]
    fn test_lagrange_eval_linear() {
        // Line through (0, 1) and (1, 3), evaluated at 2.
        let nodes = [0.0, 1.0];
        let vals = [1.0, 3.0];
        let y = lagrange_eval(&nodes, &vals, 2.0);
        assert!((y - 5.0).abs() < 1e-10, "Lagrange extrapolation: {y}");
    }

    #[test]
    fn test_lagrange_eval_empty_and_single() {
        assert_eq!(lagrange_eval(&[], &[], 1.0), 0.0);
        assert_eq!(lagrange_eval(&[3.0], &[7.0], 100.0), 7.0);
    }

    #[test]
    fn test_wrap_periodic() {
        let period = 2.0 * PI;

        // Already inside the domain: unchanged.
        let wrapped_inside = wrap_periodic(5.0, 0.0, period);
        assert!(
            (0.0..=period).contains(&wrapped_inside),
            "inside wrap failed: {wrapped_inside}"
        );
        assert!(
            (wrapped_inside - 5.0).abs() < 1e-12,
            "inside wrap should be 5.0, got {wrapped_inside}"
        );

        // One period above.
        let wrapped_above = wrap_periodic(7.0, 0.0, period);
        let expected_above = 7.0 - period;
        assert!(
            (wrapped_above - expected_above).abs() < 1e-12,
            "above wrap: {wrapped_above} vs {expected_above}"
        );

        // One period below.
        let wrapped_below = wrap_periodic(-1.0, 0.0, period);
        let expected_below = -1.0 + period;
        assert!(
            (wrapped_below - expected_below).abs() < 1e-12,
            "below wrap: {wrapped_below} vs {expected_below}"
        );
    }
}