1pub mod bicubic;
57pub mod bilinear;
58pub mod nd_grid;
59pub mod trilinear;
60
61pub use bicubic::BicubicInterp;
62pub use bilinear::BilinearInterp;
63pub use nd_grid::NdGridInterp;
64pub use trilinear::TrilinearInterp;
65
66use crate::error::{InterpolateError, InterpolateResult};
67use scirs2_core::ndarray::{Array, Array1, ArrayView1, IxDyn};
68use scirs2_core::numeric::{Float, FromPrimitive};
69use std::fmt::{Debug, Display};
70use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
71
/// Interpolation scheme used by [`TensorProductGridInterpolator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorProductMethod {
    /// Value of the grid node closest to the query point, chosen independently per axis.
    Nearest,

    /// Multilinear interpolation between the `2^ndim` corners of the enclosing grid cell.
    Multilinear,

    /// Tensor-product B-spline interpolation.
    BSpline {
        /// Polynomial degree used along every axis; each axis needs at least
        /// `degree + 1` grid points.
        degree: usize,
    },
}
91
/// Policy applied when a query coordinate lies outside the grid's axis range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoundaryHandling {
    /// Evaluation fails with an `OutOfBounds` error.
    Error,

    /// The coordinate is clamped onto the nearest grid boundary.
    Clamp,

    /// Evaluation returns NaN.
    Nan,

    /// The interpolant is continued past the grid from its boundary cell; for
    /// `Multilinear` this is linear extrapolation from the edge cell.
    Extrapolate,
}

/// `Clamp`, the same policy `TensorProductGridInterpolator::new` applies.
impl Default for BoundaryHandling {
    fn default() -> Self {
        BoundaryHandling::Clamp
    }
}
104
105#[derive(Debug, Clone)]
114pub struct TensorProductGridInterpolator<F: Float + FromPrimitive + Debug> {
115 axes: Vec<Array1<F>>,
117 values: Array<F, IxDyn>,
119 method: TensorProductMethod,
121 boundary: BoundaryHandling,
123 ndim: usize,
125 shape: Vec<usize>,
127 bspline_coeffs: Option<Array<F, IxDyn>>,
129}
130
131impl<
132 F: Float
133 + FromPrimitive
134 + Debug
135 + Display
136 + AddAssign
137 + SubAssign
138 + MulAssign
139 + DivAssign
140 + RemAssign
141 + scirs2_core::numeric::Zero
142 + 'static,
143 > TensorProductGridInterpolator<F>
144{
145 pub fn new(
161 axes: Vec<Array1<F>>,
162 values: Array<F, IxDyn>,
163 method: TensorProductMethod,
164 ) -> InterpolateResult<Self> {
165 Self::with_boundary(axes, values, method, BoundaryHandling::Clamp)
166 }
167
168 pub fn with_boundary(
177 axes: Vec<Array1<F>>,
178 values: Array<F, IxDyn>,
179 method: TensorProductMethod,
180 boundary: BoundaryHandling,
181 ) -> InterpolateResult<Self> {
182 let ndim = axes.len();
183
184 if ndim == 0 {
185 return Err(InterpolateError::empty_data(
186 "TensorProductGridInterpolator",
187 ));
188 }
189
190 if ndim != values.ndim() {
191 return Err(InterpolateError::dimension_mismatch(
192 ndim,
193 values.ndim(),
194 "TensorProductGridInterpolator: axes count vs values dimensions",
195 ));
196 }
197
198 let mut shape = Vec::with_capacity(ndim);
199 for (d, axis) in axes.iter().enumerate() {
200 let n = axis.len();
201 if n < 2 {
202 return Err(InterpolateError::insufficient_points(
203 2,
204 n,
205 &format!("TensorProductGridInterpolator axis {}", d),
206 ));
207 }
208
209 for i in 1..n {
211 if axis[i] <= axis[i - 1] {
212 return Err(InterpolateError::invalid_input(format!(
213 "Axis {} is not strictly increasing at index {}: {} <= {}",
214 d,
215 i,
216 axis[i],
217 axis[i - 1]
218 )));
219 }
220 }
221
222 if n != values.shape()[d] {
224 return Err(InterpolateError::shape_mismatch(
225 format!("{}", n),
226 format!("{}", values.shape()[d]),
227 format!("axis {} vs values dimension {}", d, d),
228 ));
229 }
230
231 if let TensorProductMethod::BSpline { degree } = method {
233 if n < degree + 1 {
234 return Err(InterpolateError::insufficient_points(
235 degree + 1,
236 n,
237 &format!(
238 "TensorProductGridInterpolator axis {} for degree-{} B-spline",
239 d, degree
240 ),
241 ));
242 }
243 }
244
245 shape.push(n);
246 }
247
248 let bspline_coeffs = if let TensorProductMethod::BSpline { degree } = method {
250 Some(Self::compute_bspline_coefficients(
251 &axes, &values, &shape, ndim, degree,
252 )?)
253 } else {
254 None
255 };
256
257 Ok(Self {
258 axes,
259 values,
260 method,
261 boundary,
262 ndim,
263 shape,
264 bspline_coeffs,
265 })
266 }
267
268 pub fn evaluate_point(&self, point: &[F]) -> InterpolateResult<F> {
279 if point.len() != self.ndim {
280 return Err(InterpolateError::dimension_mismatch(
281 self.ndim,
282 point.len(),
283 "TensorProductGridInterpolator::evaluate_point",
284 ));
285 }
286
287 match self.method {
288 TensorProductMethod::Nearest => self.nearest_interpolate(point),
289 TensorProductMethod::Multilinear => self.multilinear_interpolate(point),
290 TensorProductMethod::BSpline { degree } => self.bspline_interpolate(point, degree),
291 }
292 }
293
294 pub fn evaluate_point_array(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
296 let pt: Vec<F> = point.iter().copied().collect();
297 self.evaluate_point(&pt)
298 }
299
300 pub fn evaluate_batch(&self, points: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
306 let mut results = Vec::with_capacity(points.len());
307 for pt in points {
308 results.push(self.evaluate_point(pt)?);
309 }
310 Ok(results)
311 }
312
313 pub fn ndim(&self) -> usize {
315 self.ndim
316 }
317
318 pub fn shape(&self) -> &[usize] {
320 &self.shape
321 }
322
323 pub fn axes(&self) -> &[Array1<F>] {
325 &self.axes
326 }
327
328 pub fn values(&self) -> &Array<F, IxDyn> {
330 &self.values
331 }
332
333 fn locate_on_axis(&self, dim: usize, x: F) -> InterpolateResult<(usize, F)> {
340 let axis = &self.axes[dim];
341 let n = axis.len();
342 let lo = axis[0];
343 let hi = axis[n - 1];
344
345 if x < lo || x > hi {
347 match self.boundary {
348 BoundaryHandling::Error => {
349 return Err(InterpolateError::OutOfBounds(format!(
350 "Point coordinate {} in dimension {} is outside grid bounds [{}, {}]",
351 x, dim, lo, hi
352 )));
353 }
354 BoundaryHandling::Nan => {
355 return Ok((0, F::nan()));
356 }
357 BoundaryHandling::Clamp | BoundaryHandling::Extrapolate => {
358 if x < lo {
361 if self.boundary == BoundaryHandling::Clamp {
362 return Ok((0, F::zero()));
363 } else {
364 let h = axis[1] - axis[0];
366 let frac = if h > F::zero() {
367 (x - lo) / h
368 } else {
369 F::zero()
370 };
371 return Ok((0, frac));
372 }
373 } else {
374 if self.boundary == BoundaryHandling::Clamp {
375 return Ok((n - 2, F::one()));
376 } else {
377 let h = axis[n - 1] - axis[n - 2];
378 let frac = if h > F::zero() {
379 (x - axis[n - 2]) / h
380 } else {
381 F::one()
382 };
383 return Ok((n - 2, frac));
384 }
385 }
386 }
387 }
388 }
389
390 let mut lo_idx = 0usize;
392 let mut hi_idx = n - 1;
393
394 while hi_idx - lo_idx > 1 {
395 let mid = (lo_idx + hi_idx) / 2;
396 if x < axis[mid] {
397 hi_idx = mid;
398 } else {
399 lo_idx = mid;
400 }
401 }
402
403 let cell_lo = axis[lo_idx];
405 let cell_hi = axis[hi_idx];
406 let h = cell_hi - cell_lo;
407
408 let frac = if h > F::zero() {
409 (x - cell_lo) / h
410 } else {
411 F::zero()
412 };
413
414 Ok((lo_idx, frac))
415 }
416
417 fn nearest_interpolate(&self, point: &[F]) -> InterpolateResult<F> {
422 let mut idx = Vec::with_capacity(self.ndim);
423
424 for d in 0..self.ndim {
425 let (cell, frac) = self.locate_on_axis(d, point[d])?;
426 if frac.is_nan() {
427 return Ok(F::nan());
428 }
429 let half = F::from_f64(0.5).unwrap_or_else(|| F::one() / (F::one() + F::one()));
431 if frac <= half {
432 idx.push(cell);
433 } else {
434 idx.push((cell + 1).min(self.shape[d] - 1));
435 }
436 }
437
438 Ok(self.values[idx.as_slice()])
439 }
440
441 fn multilinear_interpolate(&self, point: &[F]) -> InterpolateResult<F> {
446 let mut cells = Vec::with_capacity(self.ndim);
447 let mut fracs = Vec::with_capacity(self.ndim);
448
449 for d in 0..self.ndim {
450 let (cell, frac) = self.locate_on_axis(d, point[d])?;
451 if frac.is_nan() {
452 return Ok(F::nan());
453 }
454 cells.push(cell);
455 fracs.push(frac);
456 }
457
458 let n_vertices = 1usize << self.ndim;
461 let mut result = F::zero();
462
463 for vertex in 0..n_vertices {
464 let mut vertex_idx = Vec::with_capacity(self.ndim);
465 let mut weight = F::one();
466
467 for d in 0..self.ndim {
468 let use_upper = (vertex >> d) & 1 == 1;
469 let idx = cells[d] + if use_upper { 1 } else { 0 };
470 vertex_idx.push(idx.min(self.shape[d] - 1));
472
473 weight = weight
474 * if use_upper {
475 fracs[d]
476 } else {
477 F::one() - fracs[d]
478 };
479 }
480
481 result = result + weight * self.values[vertex_idx.as_slice()];
482 }
483
484 Ok(result)
485 }
486
487 fn compute_bspline_coefficients(
497 axes: &[Array1<F>],
498 values: &Array<F, IxDyn>,
499 shape: &[usize],
500 ndim: usize,
501 degree: usize,
502 ) -> InterpolateResult<Array<F, IxDyn>> {
503 let mut coeffs = values.clone();
505
506 for d in 0..ndim {
508 let n = shape[d];
509 let axis = &axes[d];
510
511 let knots = Self::create_clamped_knots(axis, degree);
513 let basis = Self::compute_bspline_basis_matrix(axis, &knots, degree)?;
514
515 let total_fibers: usize = shape
518 .iter()
519 .enumerate()
520 .filter(|&(i, _)| i != d)
521 .map(|(_, &s)| s)
522 .product::<usize>()
523 .max(1);
524
525 let mut multi_idx = vec![0usize; ndim];
527 for _fiber in 0..total_fibers {
528 let mut fiber_vals = Vec::with_capacity(n);
530 for k in 0..n {
531 multi_idx[d] = k;
532 fiber_vals.push(coeffs[multi_idx.as_slice()]);
533 }
534
535 let solved = Self::solve_bspline_system(&basis, &fiber_vals, n)?;
537
538 for k in 0..n {
540 multi_idx[d] = k;
541 *coeffs.get_mut(multi_idx.as_slice()).ok_or_else(|| {
542 InterpolateError::IndexError(format!("Index {:?} out of bounds", multi_idx))
543 })? = solved[k];
544 }
545
546 Self::advance_multi_index(&mut multi_idx, shape, d);
548 }
549 }
550
551 Ok(coeffs)
552 }
553
554 fn advance_multi_index(idx: &mut [usize], shape: &[usize], skip_dim: usize) {
556 for d in 0..idx.len() {
557 if d == skip_dim {
558 continue;
559 }
560 idx[d] += 1;
561 if idx[d] < shape[d] {
562 return;
563 }
564 idx[d] = 0;
565 }
566 }
567
568 fn create_clamped_knots(axis: &Array1<F>, degree: usize) -> Vec<F> {
576 let n = axis.len();
577 let p = degree;
578 let n_knots = n + p + 1;
579 let mut knots = Vec::with_capacity(n_knots);
580
581 for _ in 0..=p {
583 knots.push(axis[0]);
584 }
585
586 if n > p + 1 {
588 for j in 1..(n - p) {
589 let mut sum = F::zero();
590 for i in j..(j + p) {
591 sum = sum + axis[i];
592 }
593 let p_f = F::from_usize(p).unwrap_or_else(|| F::one());
594 knots.push(sum / p_f);
595 }
596 }
597
598 for _ in 0..=p {
600 knots.push(axis[n - 1]);
601 }
602
603 knots
604 }
605
606 fn compute_bspline_basis_matrix(
608 axis: &Array1<F>,
609 knots: &[F],
610 degree: usize,
611 ) -> InterpolateResult<Vec<Vec<F>>> {
612 let n = axis.len();
613 let n_basis = n; let mut matrix = vec![vec![F::zero(); n_basis]; n];
615
616 for i in 0..n {
617 let x = axis[i];
618 for j in 0..n_basis {
619 matrix[i][j] = Self::bspline_basis_robust(j, degree, x, knots, n_basis);
620 }
621 }
622
623 Ok(matrix)
624 }
625
626 fn bspline_basis_robust(i: usize, k: usize, x: F, knots: &[F], n_basis: usize) -> F {
629 if k == 0 {
630 if i + 1 >= knots.len() {
631 return F::zero();
632 }
633 if x >= knots[i] && x < knots[i + 1] {
635 return F::one();
636 }
637 if i == n_basis - 1 && x == knots[i + 1] {
640 return F::one();
641 }
642 return F::zero();
643 }
644
645 let mut result = F::zero();
646
647 if i + k < knots.len() {
649 let denom = knots[i + k] - knots[i];
650 if denom > F::zero() {
651 let left = Self::bspline_basis_robust(i, k - 1, x, knots, n_basis);
652 result = result + (x - knots[i]) / denom * left;
653 }
654 }
655
656 if i + k + 1 < knots.len() {
658 let denom = knots[i + k + 1] - knots[i + 1];
659 if denom > F::zero() {
660 let right = Self::bspline_basis_robust(i + 1, k - 1, x, knots, n_basis);
661 result = result + (knots[i + k + 1] - x) / denom * right;
662 }
663 }
664
665 result
666 }
667
668 fn solve_bspline_system(matrix: &[Vec<F>], rhs: &[F], n: usize) -> InterpolateResult<Vec<F>> {
670 let mut aug: Vec<Vec<F>> = Vec::with_capacity(n);
672 for i in 0..n {
673 let mut row = Vec::with_capacity(n + 1);
674 for j in 0..n {
675 row.push(matrix[i][j]);
676 }
677 row.push(rhs[i]);
678 aug.push(row);
679 }
680
681 let eps = F::from_f64(1e-14).unwrap_or_else(|| F::epsilon());
682
683 for col in 0..n {
685 let mut max_val = aug[col][col].abs();
687 let mut max_row = col;
688 for row in (col + 1)..n {
689 let val = aug[row][col].abs();
690 if val > max_val {
691 max_val = val;
692 max_row = row;
693 }
694 }
695
696 if max_val < eps {
697 return Err(InterpolateError::numerical_error(
698 "Singular B-spline basis matrix; cannot compute coefficients",
699 ));
700 }
701
702 if max_row != col {
704 aug.swap(col, max_row);
705 }
706
707 let pivot = aug[col][col];
709 for row in (col + 1)..n {
710 let factor = aug[row][col] / pivot;
711 for j in col..=n {
712 let val = aug[col][j];
713 aug[row][j] = aug[row][j] - factor * val;
714 }
715 }
716 }
717
718 let mut result = vec![F::zero(); n];
720 for i in (0..n).rev() {
721 let mut sum = aug[i][n];
722 for j in (i + 1)..n {
723 sum = sum - aug[i][j] * result[j];
724 }
725 let diag = aug[i][i];
726 if diag.abs() < eps {
727 return Err(InterpolateError::numerical_error(
728 "Zero diagonal in back substitution",
729 ));
730 }
731 result[i] = sum / diag;
732 }
733
734 Ok(result)
735 }
736
737 fn bspline_interpolate(&self, point: &[F], degree: usize) -> InterpolateResult<F> {
739 let coeffs = self.bspline_coeffs.as_ref().ok_or_else(|| {
740 InterpolateError::InvalidState("B-spline coefficients not computed".to_string())
741 })?;
742
743 let mut basis_vals: Vec<Vec<(usize, F)>> = Vec::with_capacity(self.ndim);
745
746 for d in 0..self.ndim {
747 let axis = &self.axes[d];
748 let knots = Self::create_clamped_knots(axis, degree);
749 let n = axis.len();
750
751 let x =
753 match self.boundary {
754 BoundaryHandling::Error => {
755 if point[d] < axis[0] || point[d] > axis[n - 1] {
756 return Err(InterpolateError::OutOfBounds(format!(
757 "Point coordinate {} in dimension {} is outside grid bounds [{}, {}]",
758 point[d], d, axis[0], axis[n - 1]
759 )));
760 }
761 point[d]
762 }
763 BoundaryHandling::Nan => {
764 if point[d] < axis[0] || point[d] > axis[n - 1] {
765 return Ok(F::nan());
766 }
767 point[d]
768 }
769 BoundaryHandling::Clamp => point[d].max(axis[0]).min(axis[n - 1]),
770 BoundaryHandling::Extrapolate => point[d],
771 };
772
773 let mut vals = Vec::new();
775 for j in 0..n {
776 let b = Self::bspline_basis_robust(j, degree, x, &knots, n);
777 if b.abs() > F::epsilon() {
778 vals.push((j, b));
779 }
780 }
781
782 if vals.is_empty() {
784 let mut nearest = 0;
786 let mut min_d = (x - axis[0]).abs();
787 for j in 1..n {
788 let dist = (x - axis[j]).abs();
789 if dist < min_d {
790 min_d = dist;
791 nearest = j;
792 }
793 }
794 vals.push((nearest, F::one()));
795 }
796
797 basis_vals.push(vals);
798 }
799
800 self.tensor_product_sum(coeffs, &basis_vals, 0, &mut vec![0usize; self.ndim])
804 }
805
806 fn tensor_product_sum(
808 &self,
809 coeffs: &Array<F, IxDyn>,
810 basis_vals: &[Vec<(usize, F)>],
811 dim: usize,
812 idx: &mut Vec<usize>,
813 ) -> InterpolateResult<F> {
814 if dim == self.ndim {
815 return Ok(coeffs[idx.as_slice()]);
817 }
818
819 let mut sum = F::zero();
820 for &(j, b) in &basis_vals[dim] {
821 idx[dim] = j;
822 let inner = self.tensor_product_sum(coeffs, basis_vals, dim + 1, idx)?;
823 sum = sum + b * inner;
824 }
825
826 Ok(sum)
827 }
828}
829
830pub fn make_multilinear_interpolator<
859 F: Float
860 + FromPrimitive
861 + Debug
862 + Display
863 + AddAssign
864 + SubAssign
865 + MulAssign
866 + DivAssign
867 + RemAssign
868 + scirs2_core::numeric::Zero
869 + 'static,
870>(
871 axes: Vec<Array1<F>>,
872 values: Array<F, IxDyn>,
873) -> InterpolateResult<TensorProductGridInterpolator<F>> {
874 TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
875}
876
877pub fn make_tensor_bspline_interpolator<
885 F: Float
886 + FromPrimitive
887 + Debug
888 + Display
889 + AddAssign
890 + SubAssign
891 + MulAssign
892 + DivAssign
893 + RemAssign
894 + scirs2_core::numeric::Zero
895 + 'static,
896>(
897 axes: Vec<Array1<F>>,
898 values: Array<F, IxDyn>,
899 degree: usize,
900) -> InterpolateResult<TensorProductGridInterpolator<F>> {
901 TensorProductGridInterpolator::new(axes, values, TensorProductMethod::BSpline { degree })
902}
903
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array, Array1, IxDyn};

    /// 4x3 grid sampling the linear function f(x, y) = x + 2y.
    fn make_2d_linear_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let mut values = Array::zeros(IxDyn(&[4, 3]));
        for i in 0..4 {
            for j in 0..3 {
                values[[i, j].as_slice()] = x[i] + 2.0 * y[j];
            }
        }
        (vec![x, y], values)
    }

    /// Non-uniform 5x4 grid sampling the bilinear function f(x, y) = x * y.
    fn make_2d_nonuniform_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let x = Array1::from_vec(vec![0.0, 0.5, 1.0, 2.0, 4.0]);
        let y = Array1::from_vec(vec![0.0, 0.1, 1.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[5, 4]));
        for i in 0..5 {
            for j in 0..4 {
                values[[i, j].as_slice()] = x[i] * y[j];
            }
        }
        (vec![x, y], values)
    }

    /// 3x3x3 grid sampling f(x, y, z) = x + y + z.
    fn make_3d_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let z = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let mut values = Array::zeros(IxDyn(&[3, 3, 3]));
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    values[[i, j, k].as_slice()] = x[i] + y[j] + z[k];
                }
            }
        }
        (vec![x, y, z], values)
    }

    /// 4x4 grid sampling f(x, y) = x + 2y; enough points per axis for a cubic spline.
    fn make_2d_linear_grid_4x4() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[4, 4]));
        for i in 0..4 {
            for j in 0..4 {
                values[[i, j].as_slice()] = x[i] + 2.0 * y[j];
            }
        }
        (vec![x, y], values)
    }

    // Multilinear interpolation must reproduce the samples at every grid node.
    #[test]
    fn test_multilinear_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::Multilinear,
        )
        .expect("valid");

        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-12,
                    "At grid point ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    // A linear function lies in the multilinear space, so it is reproduced exactly.
    #[test]
    fn test_multilinear_reproduces_linear_function() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let test_points = vec![(0.5, 0.5), (1.5, 1.5), (2.5, 1.0), (0.3, 1.7)];
        for (x, y) in test_points {
            let result = interp.evaluate_point(&[x, y]).expect("valid");
            let expected = x + 2.0 * y;
            assert!(
                (result - expected).abs() < 1e-10,
                "Multilinear at ({}, {}): expected {}, got {}",
                x,
                y,
                expected,
                result
            );
        }
    }

    // x * y is bilinear, so bilinear interpolation is exact even on a non-uniform grid.
    #[test]
    fn test_multilinear_nonuniform_grid() {
        let (axes, values) = make_2d_nonuniform_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let result = interp.evaluate_point(&[0.75, 0.55]).expect("valid");
        assert!(
            (result - 0.4125).abs() < 1e-10,
            "Nonuniform bilinear: expected 0.4125, got {}",
            result
        );
    }

    #[test]
    fn test_multilinear_3d() {
        let (axes, values) = make_3d_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let result = interp.evaluate_point(&[0.5, 1.5, 0.5]).expect("valid");
        let expected = 0.5 + 1.5 + 0.5;
        assert!(
            (result - expected).abs() < 1e-10,
            "3D multilinear at (0.5, 1.5, 0.5): expected {}, got {}",
            expected,
            result
        );
    }

    // Nearest-neighbour lookup must return the exact sample at grid nodes.
    #[test]
    fn test_nearest_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::Nearest,
        )
        .expect("valid");

        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-12,
                    "Nearest at grid point ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    #[test]
    fn test_nearest_between_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Nearest)
            .expect("valid");

        // (0.3, 0.3) is nearest to node (0, 0).
        let result = interp.evaluate_point(&[0.3, 0.3]).expect("valid");
        assert!(
            (result - 0.0).abs() < 1e-10,
            "Nearest at (0.3, 0.3): expected 0.0, got {}",
            result
        );

        // (2.7, 1.7) is nearest to node (3, 2): 3 + 2 * 2 = 7.
        let result = interp.evaluate_point(&[2.7, 1.7]).expect("valid");
        assert!(
            (result - 7.0).abs() < 1e-10,
            "Nearest at (2.7, 1.7): expected 7.0, got {}",
            result
        );
    }

    // A degree-1 B-spline interpolant must reproduce the samples at the nodes.
    #[test]
    fn test_bspline_linear_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::BSpline { degree: 1 },
        )
        .expect("valid");

        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-8,
                    "BSpline(1) at grid ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    // With knots at the nodes, a degree-1 spline and multilinear interpolation coincide.
    #[test]
    fn test_bspline_linear_matches_multilinear() {
        let (axes, values) = make_2d_nonuniform_grid();
        let linear = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::Multilinear,
        )
        .expect("valid");
        let spline = TensorProductGridInterpolator::new(
            axes,
            values,
            TensorProductMethod::BSpline { degree: 1 },
        )
        .expect("valid");

        for &(x, y) in &[(0.25, 0.05), (0.75, 0.55), (1.5, 2.0), (3.0, 2.9)] {
            let a = linear.evaluate_point(&[x, y]).expect("valid");
            let b = spline.evaluate_point(&[x, y]).expect("valid");
            assert!(
                (a - b).abs() < 1e-10,
                "BSpline(1) vs multilinear at ({}, {}): {} vs {}",
                x,
                y,
                b,
                a
            );
        }
    }

    #[test]
    fn test_bspline_cubic_at_grid_points() {
        let (axes, values) = make_2d_linear_grid_4x4();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::BSpline { degree: 3 },
        )
        .expect("valid");

        for i in 0..4 {
            for j in 0..4 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-6,
                    "BSpline(3) at grid ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    // The cubic spline space contains every linear function and the interpolant is unique,
    // so linear data must be reproduced to rounding accuracy, not merely "roughly".
    #[test]
    fn test_bspline_cubic_interior_points() {
        let (axes, values) = make_2d_linear_grid_4x4();
        let interp = TensorProductGridInterpolator::new(
            axes,
            values,
            TensorProductMethod::BSpline { degree: 3 },
        )
        .expect("valid");

        for &(x, y) in &[(1.5, 0.5), (0.25, 2.75), (2.2, 1.1)] {
            let result = interp.evaluate_point(&[x, y]).expect("valid");
            let expected = x + 2.0 * y;
            assert!(
                (result - expected).abs() < 1e-8,
                "BSpline(3) at ({}, {}): expected {}, got {}",
                x,
                y,
                expected,
                result
            );
        }
    }

    // Clamp, Error and Nan policies must also apply to the B-spline method.
    #[test]
    fn test_bspline_boundary_modes() {
        let (axes, values) = make_2d_linear_grid_4x4();
        let build = |boundary: BoundaryHandling| {
            TensorProductGridInterpolator::with_boundary(
                axes.clone(),
                values.clone(),
                TensorProductMethod::BSpline { degree: 3 },
                boundary,
            )
            .expect("valid")
        };

        // (-1, 5) clamps to node (0, 3): 0 + 2 * 3 = 6.
        let clamped = build(BoundaryHandling::Clamp)
            .evaluate_point(&[-1.0, 5.0])
            .expect("valid");
        assert!(
            (clamped - 6.0).abs() < 1e-8,
            "Clamped B-spline at (-1, 5): expected 6.0, got {}",
            clamped
        );

        assert!(build(BoundaryHandling::Error)
            .evaluate_point(&[-1.0, 1.0])
            .is_err());

        let nan = build(BoundaryHandling::Nan)
            .evaluate_point(&[1.0, 5.0])
            .expect("valid");
        assert!(nan.is_nan(), "Expected NaN, got {}", nan);
    }

    // Clamp pins out-of-range coordinates to the nearest grid boundary.
    #[test]
    fn test_boundary_clamp() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Clamp,
        )
        .expect("valid");

        // (-1, -1) clamps to node (0, 0).
        let result = interp.evaluate_point(&[-1.0, -1.0]).expect("valid");
        assert!(
            (result - 0.0).abs() < 1e-10,
            "Clamped at (-1,-1): expected 0.0, got {}",
            result
        );

        // (10, 10) clamps to node (3, 2): 3 + 2 * 2 = 7.
        let result = interp.evaluate_point(&[10.0, 10.0]).expect("valid");
        assert!(
            (result - 7.0).abs() < 1e-10,
            "Clamped at (10,10): expected 7.0, got {}",
            result
        );
    }

    #[test]
    fn test_boundary_error() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Error,
        )
        .expect("valid");

        let result = interp.evaluate_point(&[-1.0, 0.5]);
        assert!(result.is_err(), "Should error for out-of-bounds point");
    }

    #[test]
    fn test_boundary_nan() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Nan,
        )
        .expect("valid");

        let result = interp.evaluate_point(&[-1.0, 0.5]).expect("valid");
        assert!(result.is_nan(), "Should return NaN for out-of-bounds point");
    }

    // Linear extrapolation of x + 2y from the edge cell stays exact.
    #[test]
    fn test_boundary_extrapolate() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Extrapolate,
        )
        .expect("valid");

        let result = interp.evaluate_point(&[-0.5, 0.5]).expect("valid");
        assert!(
            (result - 0.5).abs() < 1e-10,
            "Extrapolated at (-0.5, 0.5): expected 0.5, got {}",
            result
        );
    }

    #[test]
    fn test_batch_evaluation() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let points = vec![vec![0.5, 0.5], vec![1.5, 1.0], vec![2.0, 1.5]];
        let results = interp.evaluate_batch(&points).expect("valid");

        assert_eq!(results.len(), 3);
        assert!((results[0] - (0.5 + 1.0)).abs() < 1e-10);
        assert!((results[1] - (1.5 + 2.0)).abs() < 1e-10);
        assert!((results[2] - (2.0 + 3.0)).abs() < 1e-10);
    }

    // The array-view entry point must agree with the slice entry point.
    #[test]
    fn test_evaluate_point_array_matches_slice() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let pt = Array1::from_vec(vec![1.25, 0.75]);
        let from_view = interp.evaluate_point_array(&pt.view()).expect("valid");
        let from_slice = interp.evaluate_point(&[1.25, 0.75]).expect("valid");
        assert!((from_view - from_slice).abs() < 1e-15);
        assert!((from_slice - (1.25 + 1.5)).abs() < 1e-10);
    }

    #[test]
    fn test_empty_axes_rejected() {
        let axes: Vec<Array1<f64>> = vec![];
        let values = Array::zeros(IxDyn(&[]));
        let result =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear);
        assert!(result.is_err(), "Empty axes should be rejected");
    }

    #[test]
    fn test_too_few_points_rejected() {
        let x = Array1::from_vec(vec![0.0]);
        let values = Array::zeros(IxDyn(&[1]));
        let result =
            TensorProductGridInterpolator::new(vec![x], values, TensorProductMethod::Multilinear);
        assert!(result.is_err(), "Single-point axis should be rejected");
    }

    #[test]
    fn test_nonsorted_axis_rejected() {
        let x = Array1::from_vec(vec![0.0, 2.0, 1.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);
        let values = Array::zeros(IxDyn(&[3, 2]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::Multilinear,
        );
        assert!(result.is_err(), "Non-sorted axis should be rejected");
    }

    // Axes must be strictly increasing, so repeated coordinates are invalid.
    #[test]
    fn test_duplicate_axis_coordinate_rejected() {
        let x = Array1::from_vec(vec![0.0, 1.0, 1.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);
        let values = Array::zeros(IxDyn(&[3, 2]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::Multilinear,
        );
        assert!(result.is_err(), "Repeated axis coordinate should be rejected");
    }

    #[test]
    fn test_shape_mismatch_rejected() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);
        let values = Array::zeros(IxDyn(&[3, 3]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::Multilinear,
        );
        assert!(result.is_err(), "Shape mismatch should be rejected");
    }

    #[test]
    fn test_wrong_dimension_query_rejected() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let result = interp.evaluate_point(&[1.0]);
        assert!(result.is_err(), "Wrong dimension query should be rejected");
    }

    #[test]
    fn test_accessors() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        assert_eq!(interp.ndim(), 2);
        assert_eq!(interp.shape(), &[4, 3]);
        assert_eq!(interp.axes().len(), 2);
    }

    #[test]
    fn test_make_multilinear_interpolator() {
        let (axes, values) = make_2d_linear_grid();
        let interp = make_multilinear_interpolator(axes, values).expect("valid");
        let result = interp.evaluate_point(&[1.0, 1.0]).expect("valid");
        assert!((result - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_make_tensor_bspline_interpolator() {
        let (axes, values) = make_2d_linear_grid();
        let interp = make_tensor_bspline_interpolator(axes, values, 1).expect("valid");
        let result = interp.evaluate_point(&[1.0, 1.0]).expect("valid");
        assert!(
            (result - 3.0).abs() < 1e-6,
            "BSpline at (1,1): expected 3.0, got {}",
            result
        );
    }

    // Refining the grid must shrink the multilinear error on a smooth non-linear function.
    #[test]
    fn test_multilinear_convergence_quadratic() {
        let test_point = [0.37_f64, 0.63];
        let exact_value = 0.37 * 0.37 + 0.63 * 0.63;

        let mut errors = Vec::new();
        for &n in &[5, 10, 20, 40] {
            let x = Array1::linspace(0.0, 1.0, n);
            let y = Array1::linspace(0.0, 1.0, n);
            let mut values = Array::zeros(IxDyn(&[n, n]));
            for i in 0..n {
                for j in 0..n {
                    values[[i, j].as_slice()] = x[i] * x[i] + y[j] * y[j];
                }
            }

            let interp = TensorProductGridInterpolator::new(
                vec![x, y],
                values,
                TensorProductMethod::Multilinear,
            )
            .expect("valid");

            let result = interp.evaluate_point(&test_point).expect("valid");
            let error = (result - exact_value).abs();
            errors.push(error);
        }

        assert!(
            errors[errors.len() - 1] < errors[0],
            "Error should decrease: first={}, last={}",
            errors[0],
            errors[errors.len() - 1]
        );

        assert!(
            errors[errors.len() - 1] < 0.01,
            "Multilinear should converge to the exact value: final error = {}",
            errors[errors.len() - 1]
        );
    }

    // In 1-D, multilinear reduces to piecewise-linear interpolation of the samples.
    #[test]
    fn test_1d_multilinear() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[4]));
        for i in 0..4 {
            values[[i].as_slice()] = (i as f64) * (i as f64);
        }

        let interp =
            TensorProductGridInterpolator::new(vec![x], values, TensorProductMethod::Multilinear)
                .expect("valid");

        // Halfway between samples 0 and 1.
        let result = interp.evaluate_point(&[0.5]).expect("valid");
        assert!(
            (result - 0.5).abs() < 1e-10,
            "1D multilinear at 0.5: expected 0.5, got {}",
            result
        );
    }

    // A degree-p spline needs at least p + 1 points per axis.
    #[test]
    fn test_bspline_insufficient_points_for_degree() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let values = Array::zeros(IxDyn(&[3, 3]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::BSpline { degree: 3 },
        );
        assert!(
            result.is_err(),
            "Should reject degree 3 with only 3 points per axis"
        );
    }
}