Skip to main content

scirs2_interpolate/tensor_product/
mod.rs

1//! Tensor product grid interpolation
2//!
3//! This module provides N-dimensional interpolation on structured (tensor product)
4//! grids, where data points lie on a rectilinear grid defined by the Cartesian
5//! product of 1D coordinate arrays.
6//!
7//! ## Methods
8//!
9//! - **Multilinear interpolation**: N-dimensional generalization of bilinear
10//!   interpolation. Piecewise linear in each dimension. C0 continuous.
11//!
12//! - **Tensor product B-spline interpolation**: Uses B-splines along each
13//!   dimension for smooth interpolation. Configurable spline degree.
14//!   C^(k-1) continuous for degree-k splines.
15//!
16//! - **Nearest grid point**: Returns the value at the nearest grid point.
17//!   Piecewise constant. Fast evaluation.
18//!
19//! ## Grid types
20//!
21//! All methods support non-uniform grid spacing. The grid is defined by
22//! N one-dimensional coordinate arrays, one per dimension. Grid points
23//! must be strictly increasing along each axis.
24//!
25//! ## Examples
26//!
27//! ```rust
28//! use scirs2_core::ndarray::{Array, Array1, IxDyn};
29//! use scirs2_interpolate::tensor_product::{
30//!     TensorProductGridInterpolator, TensorProductMethod,
31//! };
32//!
33//! // Create a 2D grid with non-uniform spacing
34//! let x = Array1::from_vec(vec![0.0_f64, 0.5, 1.0, 2.0]);
35//! let y = Array1::from_vec(vec![0.0_f64, 1.0, 3.0]);
36//!
37//! // Values: z = x * y
38//! let mut values: scirs2_core::ndarray::Array<f64, IxDyn> = Array::zeros(IxDyn(&[4, 3]));
39//! for i in 0..4 {
40//!     for j in 0..3 {
41//!         values[[i, j].as_slice()] = x[i] * y[j];
42//!     }
43//! }
44//!
45//! let interp = TensorProductGridInterpolator::new(
46//!     vec![x, y],
47//!     values,
48//!     TensorProductMethod::Multilinear,
49//! ).expect("valid interpolator");
50//!
51//! let result = interp.evaluate_point(&[0.75_f64, 2.0]).expect("valid");
52//! // At (0.75, 2.0): 0.75 * 2.0 = 1.5
53//! assert!((result - 1.5_f64).abs() < 0.01);
54//! ```
55
56pub mod bicubic;
57pub mod bilinear;
58pub mod nd_grid;
59pub mod trilinear;
60
61pub use bicubic::BicubicInterp;
62pub use bilinear::BilinearInterp;
63pub use nd_grid::NdGridInterp;
64pub use trilinear::TrilinearInterp;
65
66use crate::error::{InterpolateError, InterpolateResult};
67use scirs2_core::ndarray::{Array, Array1, ArrayView1, IxDyn};
68use scirs2_core::numeric::{Float, FromPrimitive};
69use std::fmt::{Debug, Display};
70use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
71
72// ---------------------------------------------------------------------------
73// Interpolation method
74// ---------------------------------------------------------------------------
75
/// Interpolation scheme used by a tensor product grid interpolator.
///
/// The enum is `Copy`, `Eq` and `Hash`, so it can be stored cheaply and used
/// as a map/set key (e.g. to cache interpolators per method).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TensorProductMethod {
    /// Nearest grid point interpolation (piecewise constant)
    Nearest,

    /// Multilinear interpolation (N-dimensional extension of bilinear)
    Multilinear,

    /// Tensor product B-spline interpolation with specified degree
    BSpline {
        /// Degree of the B-spline (1 = linear, 3 = cubic)
        degree: usize,
    },
}
91
/// Boundary handling for tensor product interpolation
///
/// Selects what happens when a query coordinate lies outside the range of
/// its axis. The default is [`BoundaryHandling::Clamp`], which is also the
/// policy used by `TensorProductGridInterpolator::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BoundaryHandling {
    /// Return an error for points outside the grid
    Error,
    /// Clamp to the grid boundary
    #[default]
    Clamp,
    /// Return NaN for points outside the grid
    Nan,
    /// Extrapolate beyond the grid using the boundary cell
    Extrapolate,
}
104
105// ---------------------------------------------------------------------------
106// Tensor product grid interpolator
107// ---------------------------------------------------------------------------
108
109/// N-dimensional interpolator on a tensor product (rectilinear) grid
110///
111/// The grid is defined by the Cartesian product of 1D coordinate arrays.
112/// Grid spacing may be non-uniform along each axis.
113#[derive(Debug, Clone)]
114pub struct TensorProductGridInterpolator<F: Float + FromPrimitive + Debug> {
115    /// 1D coordinate arrays for each dimension
116    axes: Vec<Array1<F>>,
117    /// Values on the grid, shape matching the axes lengths
118    values: Array<F, IxDyn>,
119    /// Interpolation method
120    method: TensorProductMethod,
121    /// Boundary handling mode
122    boundary: BoundaryHandling,
123    /// Number of dimensions
124    ndim: usize,
125    /// Grid shape (length along each axis)
126    shape: Vec<usize>,
127    /// Precomputed 1D B-spline coefficients per axis (for BSpline method)
128    bspline_coeffs: Option<Array<F, IxDyn>>,
129}
130
131impl<
132        F: Float
133            + FromPrimitive
134            + Debug
135            + Display
136            + AddAssign
137            + SubAssign
138            + MulAssign
139            + DivAssign
140            + RemAssign
141            + scirs2_core::numeric::Zero
142            + 'static,
143    > TensorProductGridInterpolator<F>
144{
145    /// Create a new tensor product grid interpolator
146    ///
147    /// # Arguments
148    ///
149    /// * `axes` - 1D coordinate arrays, one per dimension. Must be strictly increasing.
150    /// * `values` - N-dimensional array of values on the grid.
151    /// * `method` - Interpolation method.
152    ///
153    /// # Errors
154    ///
155    /// Returns an error if:
156    /// - `axes` is empty
157    /// - Any axis has fewer than 2 points (or fewer than degree+1 for BSpline)
158    /// - Axis coordinates are not strictly increasing
159    /// - The values array shape does not match the axes lengths
160    pub fn new(
161        axes: Vec<Array1<F>>,
162        values: Array<F, IxDyn>,
163        method: TensorProductMethod,
164    ) -> InterpolateResult<Self> {
165        Self::with_boundary(axes, values, method, BoundaryHandling::Clamp)
166    }
167
168    /// Create a new tensor product grid interpolator with boundary handling
169    ///
170    /// # Arguments
171    ///
172    /// * `axes` - 1D coordinate arrays, one per dimension.
173    /// * `values` - N-dimensional array of values on the grid.
174    /// * `method` - Interpolation method.
175    /// * `boundary` - How to handle out-of-bound query points.
176    pub fn with_boundary(
177        axes: Vec<Array1<F>>,
178        values: Array<F, IxDyn>,
179        method: TensorProductMethod,
180        boundary: BoundaryHandling,
181    ) -> InterpolateResult<Self> {
182        let ndim = axes.len();
183
184        if ndim == 0 {
185            return Err(InterpolateError::empty_data(
186                "TensorProductGridInterpolator",
187            ));
188        }
189
190        if ndim != values.ndim() {
191            return Err(InterpolateError::dimension_mismatch(
192                ndim,
193                values.ndim(),
194                "TensorProductGridInterpolator: axes count vs values dimensions",
195            ));
196        }
197
198        let mut shape = Vec::with_capacity(ndim);
199        for (d, axis) in axes.iter().enumerate() {
200            let n = axis.len();
201            if n < 2 {
202                return Err(InterpolateError::insufficient_points(
203                    2,
204                    n,
205                    &format!("TensorProductGridInterpolator axis {}", d),
206                ));
207            }
208
209            // Check strictly increasing
210            for i in 1..n {
211                if axis[i] <= axis[i - 1] {
212                    return Err(InterpolateError::invalid_input(format!(
213                        "Axis {} is not strictly increasing at index {}: {} <= {}",
214                        d,
215                        i,
216                        axis[i],
217                        axis[i - 1]
218                    )));
219                }
220            }
221
222            // Check shape matches
223            if n != values.shape()[d] {
224                return Err(InterpolateError::shape_mismatch(
225                    format!("{}", n),
226                    format!("{}", values.shape()[d]),
227                    format!("axis {} vs values dimension {}", d, d),
228                ));
229            }
230
231            // B-spline degree check
232            if let TensorProductMethod::BSpline { degree } = method {
233                if n < degree + 1 {
234                    return Err(InterpolateError::insufficient_points(
235                        degree + 1,
236                        n,
237                        &format!(
238                            "TensorProductGridInterpolator axis {} for degree-{} B-spline",
239                            d, degree
240                        ),
241                    ));
242                }
243            }
244
245            shape.push(n);
246        }
247
248        // For B-spline method, precompute coefficients
249        let bspline_coeffs = if let TensorProductMethod::BSpline { degree } = method {
250            Some(Self::compute_bspline_coefficients(
251                &axes, &values, &shape, ndim, degree,
252            )?)
253        } else {
254            None
255        };
256
257        Ok(Self {
258            axes,
259            values,
260            method,
261            boundary,
262            ndim,
263            shape,
264            bspline_coeffs,
265        })
266    }
267
268    /// Evaluate the interpolator at a single point
269    ///
270    /// # Arguments
271    ///
272    /// * `point` - Coordinates of the query point, one per dimension
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if the point dimension does not match the grid dimension,
277    /// or if boundary handling is set to Error and the point is outside the grid.
278    pub fn evaluate_point(&self, point: &[F]) -> InterpolateResult<F> {
279        if point.len() != self.ndim {
280            return Err(InterpolateError::dimension_mismatch(
281                self.ndim,
282                point.len(),
283                "TensorProductGridInterpolator::evaluate_point",
284            ));
285        }
286
287        match self.method {
288            TensorProductMethod::Nearest => self.nearest_interpolate(point),
289            TensorProductMethod::Multilinear => self.multilinear_interpolate(point),
290            TensorProductMethod::BSpline { degree } => self.bspline_interpolate(point, degree),
291        }
292    }
293
294    /// Evaluate the interpolator at a single point given as an ArrayView
295    pub fn evaluate_point_array(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
296        let pt: Vec<F> = point.iter().copied().collect();
297        self.evaluate_point(&pt)
298    }
299
300    /// Evaluate the interpolator at multiple points
301    ///
302    /// # Arguments
303    ///
304    /// * `points` - Array of query points, shape (n_queries, n_dims)
305    pub fn evaluate_batch(&self, points: &[Vec<F>]) -> InterpolateResult<Vec<F>> {
306        let mut results = Vec::with_capacity(points.len());
307        for pt in points {
308            results.push(self.evaluate_point(pt)?);
309        }
310        Ok(results)
311    }
312
313    /// Get the number of dimensions
314    pub fn ndim(&self) -> usize {
315        self.ndim
316    }
317
318    /// Get the grid shape
319    pub fn shape(&self) -> &[usize] {
320        &self.shape
321    }
322
323    /// Get a reference to the axes
324    pub fn axes(&self) -> &[Array1<F>] {
325        &self.axes
326    }
327
328    /// Get a reference to the values
329    pub fn values(&self) -> &Array<F, IxDyn> {
330        &self.values
331    }
332
333    // -----------------------------------------------------------------------
334    // Private: locate point on grid
335    // -----------------------------------------------------------------------
336
337    /// Find the cell index and fractional position for a coordinate along one axis
338    /// Returns (cell_index, fraction) where fraction is in [0, 1]
339    fn locate_on_axis(&self, dim: usize, x: F) -> InterpolateResult<(usize, F)> {
340        let axis = &self.axes[dim];
341        let n = axis.len();
342        let lo = axis[0];
343        let hi = axis[n - 1];
344
345        // Handle boundary
346        if x < lo || x > hi {
347            match self.boundary {
348                BoundaryHandling::Error => {
349                    return Err(InterpolateError::OutOfBounds(format!(
350                        "Point coordinate {} in dimension {} is outside grid bounds [{}, {}]",
351                        x, dim, lo, hi
352                    )));
353                }
354                BoundaryHandling::Nan => {
355                    return Ok((0, F::nan()));
356                }
357                BoundaryHandling::Clamp | BoundaryHandling::Extrapolate => {
358                    // For Clamp, we clamp to boundary
359                    // For Extrapolate, we still use the boundary cell but allow fraction outside [0,1]
360                    if x < lo {
361                        if self.boundary == BoundaryHandling::Clamp {
362                            return Ok((0, F::zero()));
363                        } else {
364                            // Extrapolate: compute fraction (will be negative)
365                            let h = axis[1] - axis[0];
366                            let frac = if h > F::zero() {
367                                (x - lo) / h
368                            } else {
369                                F::zero()
370                            };
371                            return Ok((0, frac));
372                        }
373                    } else {
374                        if self.boundary == BoundaryHandling::Clamp {
375                            return Ok((n - 2, F::one()));
376                        } else {
377                            let h = axis[n - 1] - axis[n - 2];
378                            let frac = if h > F::zero() {
379                                (x - axis[n - 2]) / h
380                            } else {
381                                F::one()
382                            };
383                            return Ok((n - 2, frac));
384                        }
385                    }
386                }
387            }
388        }
389
390        // Binary search for the cell containing x
391        let mut lo_idx = 0usize;
392        let mut hi_idx = n - 1;
393
394        while hi_idx - lo_idx > 1 {
395            let mid = (lo_idx + hi_idx) / 2;
396            if x < axis[mid] {
397                hi_idx = mid;
398            } else {
399                lo_idx = mid;
400            }
401        }
402
403        // lo_idx is now the cell index (x is between axis[lo_idx] and axis[hi_idx])
404        let cell_lo = axis[lo_idx];
405        let cell_hi = axis[hi_idx];
406        let h = cell_hi - cell_lo;
407
408        let frac = if h > F::zero() {
409            (x - cell_lo) / h
410        } else {
411            F::zero()
412        };
413
414        Ok((lo_idx, frac))
415    }
416
417    // -----------------------------------------------------------------------
418    // Nearest interpolation
419    // -----------------------------------------------------------------------
420
421    fn nearest_interpolate(&self, point: &[F]) -> InterpolateResult<F> {
422        let mut idx = Vec::with_capacity(self.ndim);
423
424        for d in 0..self.ndim {
425            let (cell, frac) = self.locate_on_axis(d, point[d])?;
426            if frac.is_nan() {
427                return Ok(F::nan());
428            }
429            // Pick the nearer grid point
430            let half = F::from_f64(0.5).unwrap_or_else(|| F::one() / (F::one() + F::one()));
431            if frac <= half {
432                idx.push(cell);
433            } else {
434                idx.push((cell + 1).min(self.shape[d] - 1));
435            }
436        }
437
438        Ok(self.values[idx.as_slice()])
439    }
440
441    // -----------------------------------------------------------------------
442    // Multilinear interpolation
443    // -----------------------------------------------------------------------
444
445    fn multilinear_interpolate(&self, point: &[F]) -> InterpolateResult<F> {
446        let mut cells = Vec::with_capacity(self.ndim);
447        let mut fracs = Vec::with_capacity(self.ndim);
448
449        for d in 0..self.ndim {
450            let (cell, frac) = self.locate_on_axis(d, point[d])?;
451            if frac.is_nan() {
452                return Ok(F::nan());
453            }
454            cells.push(cell);
455            fracs.push(frac);
456        }
457
458        // Compute the multilinear interpolation by iterating over all 2^ndim vertices
459        // of the hypercube defined by the cell
460        let n_vertices = 1usize << self.ndim;
461        let mut result = F::zero();
462
463        for vertex in 0..n_vertices {
464            let mut vertex_idx = Vec::with_capacity(self.ndim);
465            let mut weight = F::one();
466
467            for d in 0..self.ndim {
468                let use_upper = (vertex >> d) & 1 == 1;
469                let idx = cells[d] + if use_upper { 1 } else { 0 };
470                // Safety: the cell index guarantees idx and idx+1 are valid
471                vertex_idx.push(idx.min(self.shape[d] - 1));
472
473                weight = weight
474                    * if use_upper {
475                        fracs[d]
476                    } else {
477                        F::one() - fracs[d]
478                    };
479            }
480
481            result = result + weight * self.values[vertex_idx.as_slice()];
482        }
483
484        Ok(result)
485    }
486
487    // -----------------------------------------------------------------------
488    // Tensor product B-spline interpolation
489    // -----------------------------------------------------------------------
490
491    /// Compute the B-spline coefficients by solving the tensor product system
492    ///
493    /// For each dimension, we solve a 1D B-spline fitting problem along that
494    /// axis while keeping all other indices fixed. This is done dimension by
495    /// dimension.
496    fn compute_bspline_coefficients(
497        axes: &[Array1<F>],
498        values: &Array<F, IxDyn>,
499        shape: &[usize],
500        ndim: usize,
501        degree: usize,
502    ) -> InterpolateResult<Array<F, IxDyn>> {
503        // Start with the original values
504        let mut coeffs = values.clone();
505
506        // Process each dimension
507        for d in 0..ndim {
508            let n = shape[d];
509            let axis = &axes[d];
510
511            // Create the B-spline basis matrix for this axis
512            let knots = Self::create_clamped_knots(axis, degree);
513            let basis = Self::compute_bspline_basis_matrix(axis, &knots, degree)?;
514
515            // Solve the linear system along this dimension for each "fiber"
516            // A fiber is obtained by fixing all indices except dimension d
517            let total_fibers: usize = shape
518                .iter()
519                .enumerate()
520                .filter(|&(i, _)| i != d)
521                .map(|(_, &s)| s)
522                .product::<usize>()
523                .max(1);
524
525            // For each fiber, extract the 1D data, solve, and put back
526            let mut multi_idx = vec![0usize; ndim];
527            for _fiber in 0..total_fibers {
528                // Extract the 1D slice along dimension d
529                let mut fiber_vals = Vec::with_capacity(n);
530                for k in 0..n {
531                    multi_idx[d] = k;
532                    fiber_vals.push(coeffs[multi_idx.as_slice()]);
533                }
534
535                // Solve the 1D B-spline system: basis * c = fiber_vals
536                let solved = Self::solve_bspline_system(&basis, &fiber_vals, n)?;
537
538                // Write back
539                for k in 0..n {
540                    multi_idx[d] = k;
541                    *coeffs.get_mut(multi_idx.as_slice()).ok_or_else(|| {
542                        InterpolateError::IndexError(format!("Index {:?} out of bounds", multi_idx))
543                    })? = solved[k];
544                }
545
546                // Advance the multi-index (skip dimension d)
547                Self::advance_multi_index(&mut multi_idx, shape, d);
548            }
549        }
550
551        Ok(coeffs)
552    }
553
554    /// Advance a multi-index by incrementing all dimensions except `skip_dim`
555    fn advance_multi_index(idx: &mut [usize], shape: &[usize], skip_dim: usize) {
556        for d in 0..idx.len() {
557            if d == skip_dim {
558                continue;
559            }
560            idx[d] += 1;
561            if idx[d] < shape[d] {
562                return;
563            }
564            idx[d] = 0;
565        }
566    }
567
568    /// Create clamped knot vector for B-spline interpolation
569    ///
570    /// For n data points and degree p, the clamped knot vector has:
571    /// - (p+1) copies of the first coordinate
572    /// - (n-p-1) interior knots (averaging the data points)
573    /// - (p+1) copies of the last coordinate
574    /// Total: n + p + 1 knots
575    fn create_clamped_knots(axis: &Array1<F>, degree: usize) -> Vec<F> {
576        let n = axis.len();
577        let p = degree;
578        let n_knots = n + p + 1;
579        let mut knots = Vec::with_capacity(n_knots);
580
581        // (p+1) copies of first value
582        for _ in 0..=p {
583            knots.push(axis[0]);
584        }
585
586        // Interior knots: use averaging of data points (de Boor approach)
587        if n > p + 1 {
588            for j in 1..(n - p) {
589                let mut sum = F::zero();
590                for i in j..(j + p) {
591                    sum = sum + axis[i];
592                }
593                let p_f = F::from_usize(p).unwrap_or_else(|| F::one());
594                knots.push(sum / p_f);
595            }
596        }
597
598        // (p+1) copies of last value
599        for _ in 0..=p {
600            knots.push(axis[n - 1]);
601        }
602
603        knots
604    }
605
606    /// Compute the B-spline basis matrix: B[i][j] = B_{j,degree}(axis[i])
607    fn compute_bspline_basis_matrix(
608        axis: &Array1<F>,
609        knots: &[F],
610        degree: usize,
611    ) -> InterpolateResult<Vec<Vec<F>>> {
612        let n = axis.len();
613        let n_basis = n; // n basis functions for n data points
614        let mut matrix = vec![vec![F::zero(); n_basis]; n];
615
616        for i in 0..n {
617            let x = axis[i];
618            for j in 0..n_basis {
619                matrix[i][j] = Self::bspline_basis_robust(j, degree, x, knots, n_basis);
620            }
621        }
622
623        Ok(matrix)
624    }
625
626    /// Evaluate B-spline basis function B_{i,k}(x) using de Boor recursion
627    /// with robust handling of the right endpoint
628    fn bspline_basis_robust(i: usize, k: usize, x: F, knots: &[F], n_basis: usize) -> F {
629        if k == 0 {
630            if i + 1 >= knots.len() {
631                return F::zero();
632            }
633            // Standard indicator: [knots[i], knots[i+1])
634            if x >= knots[i] && x < knots[i + 1] {
635                return F::one();
636            }
637            // Special handling for last basis function at the right endpoint:
638            // The last basis function should be 1 at x = knots.last()
639            if i == n_basis - 1 && x == knots[i + 1] {
640                return F::one();
641            }
642            return F::zero();
643        }
644
645        let mut result = F::zero();
646
647        // Left term: (x - t_i) / (t_{i+k} - t_i) * B_{i,k-1}(x)
648        if i + k < knots.len() {
649            let denom = knots[i + k] - knots[i];
650            if denom > F::zero() {
651                let left = Self::bspline_basis_robust(i, k - 1, x, knots, n_basis);
652                result = result + (x - knots[i]) / denom * left;
653            }
654        }
655
656        // Right term: (t_{i+k+1} - x) / (t_{i+k+1} - t_{i+1}) * B_{i+1,k-1}(x)
657        if i + k + 1 < knots.len() {
658            let denom = knots[i + k + 1] - knots[i + 1];
659            if denom > F::zero() {
660                let right = Self::bspline_basis_robust(i + 1, k - 1, x, knots, n_basis);
661                result = result + (knots[i + k + 1] - x) / denom * right;
662            }
663        }
664
665        result
666    }
667
668    /// Solve a banded-like linear system B * c = f using simple Gaussian elimination
669    fn solve_bspline_system(matrix: &[Vec<F>], rhs: &[F], n: usize) -> InterpolateResult<Vec<F>> {
670        // Build augmented matrix
671        let mut aug: Vec<Vec<F>> = Vec::with_capacity(n);
672        for i in 0..n {
673            let mut row = Vec::with_capacity(n + 1);
674            for j in 0..n {
675                row.push(matrix[i][j]);
676            }
677            row.push(rhs[i]);
678            aug.push(row);
679        }
680
681        let eps = F::from_f64(1e-14).unwrap_or_else(|| F::epsilon());
682
683        // Forward elimination with partial pivoting
684        for col in 0..n {
685            // Find pivot
686            let mut max_val = aug[col][col].abs();
687            let mut max_row = col;
688            for row in (col + 1)..n {
689                let val = aug[row][col].abs();
690                if val > max_val {
691                    max_val = val;
692                    max_row = row;
693                }
694            }
695
696            if max_val < eps {
697                return Err(InterpolateError::numerical_error(
698                    "Singular B-spline basis matrix; cannot compute coefficients",
699                ));
700            }
701
702            // Swap rows
703            if max_row != col {
704                aug.swap(col, max_row);
705            }
706
707            // Eliminate
708            let pivot = aug[col][col];
709            for row in (col + 1)..n {
710                let factor = aug[row][col] / pivot;
711                for j in col..=n {
712                    let val = aug[col][j];
713                    aug[row][j] = aug[row][j] - factor * val;
714                }
715            }
716        }
717
718        // Back substitution
719        let mut result = vec![F::zero(); n];
720        for i in (0..n).rev() {
721            let mut sum = aug[i][n];
722            for j in (i + 1)..n {
723                sum = sum - aug[i][j] * result[j];
724            }
725            let diag = aug[i][i];
726            if diag.abs() < eps {
727                return Err(InterpolateError::numerical_error(
728                    "Zero diagonal in back substitution",
729                ));
730            }
731            result[i] = sum / diag;
732        }
733
734        Ok(result)
735    }
736
737    /// Evaluate tensor product B-spline at a point using precomputed coefficients
738    fn bspline_interpolate(&self, point: &[F], degree: usize) -> InterpolateResult<F> {
739        let coeffs = self.bspline_coeffs.as_ref().ok_or_else(|| {
740            InterpolateError::InvalidState("B-spline coefficients not computed".to_string())
741        })?;
742
743        // For each dimension, compute the B-spline basis values at the query coordinate
744        let mut basis_vals: Vec<Vec<(usize, F)>> = Vec::with_capacity(self.ndim);
745
746        for d in 0..self.ndim {
747            let axis = &self.axes[d];
748            let knots = Self::create_clamped_knots(axis, degree);
749            let n = axis.len();
750
751            // Clamp point to grid bounds
752            let x =
753                match self.boundary {
754                    BoundaryHandling::Error => {
755                        if point[d] < axis[0] || point[d] > axis[n - 1] {
756                            return Err(InterpolateError::OutOfBounds(format!(
757                            "Point coordinate {} in dimension {} is outside grid bounds [{}, {}]",
758                            point[d], d, axis[0], axis[n - 1]
759                        )));
760                        }
761                        point[d]
762                    }
763                    BoundaryHandling::Nan => {
764                        if point[d] < axis[0] || point[d] > axis[n - 1] {
765                            return Ok(F::nan());
766                        }
767                        point[d]
768                    }
769                    BoundaryHandling::Clamp => point[d].max(axis[0]).min(axis[n - 1]),
770                    BoundaryHandling::Extrapolate => point[d],
771                };
772
773            // Compute non-zero basis functions at x
774            let mut vals = Vec::new();
775            for j in 0..n {
776                let b = Self::bspline_basis_robust(j, degree, x, &knots, n);
777                if b.abs() > F::epsilon() {
778                    vals.push((j, b));
779                }
780            }
781
782            // If no basis functions are non-zero (edge case), use nearest
783            if vals.is_empty() {
784                // Find nearest grid point
785                let mut nearest = 0;
786                let mut min_d = (x - axis[0]).abs();
787                for j in 1..n {
788                    let dist = (x - axis[j]).abs();
789                    if dist < min_d {
790                        min_d = dist;
791                        nearest = j;
792                    }
793                }
794                vals.push((nearest, F::one()));
795            }
796
797            basis_vals.push(vals);
798        }
799
800        // Compute the tensor product sum:
801        // f(x) = sum_{j1,..,jN} c[j1,..,jN] * B_{j1}(x1) * ... * B_{jN}(xN)
802        // Only iterate over combinations where all basis values are non-zero
803        self.tensor_product_sum(coeffs, &basis_vals, 0, &mut vec![0usize; self.ndim])
804    }
805
806    /// Recursively compute tensor product sum over non-zero basis function indices
807    fn tensor_product_sum(
808        &self,
809        coeffs: &Array<F, IxDyn>,
810        basis_vals: &[Vec<(usize, F)>],
811        dim: usize,
812        idx: &mut Vec<usize>,
813    ) -> InterpolateResult<F> {
814        if dim == self.ndim {
815            // All dimensions have been indexed; get the coefficient
816            return Ok(coeffs[idx.as_slice()]);
817        }
818
819        let mut sum = F::zero();
820        for &(j, b) in &basis_vals[dim] {
821            idx[dim] = j;
822            let inner = self.tensor_product_sum(coeffs, basis_vals, dim + 1, idx)?;
823            sum = sum + b * inner;
824        }
825
826        Ok(sum)
827    }
828}
829
830// ---------------------------------------------------------------------------
831// Convenience constructors
832// ---------------------------------------------------------------------------
833
834/// Create a multilinear interpolator on a tensor product grid
835///
836/// # Arguments
837///
838/// * `axes` - 1D coordinate arrays for each dimension
839/// * `values` - Values on the grid
840///
841/// # Examples
842///
843/// ```rust
844/// use scirs2_core::ndarray::{Array, Array1, IxDyn};
845/// use scirs2_interpolate::tensor_product::make_multilinear_interpolator;
846///
847/// let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
848/// let y = Array1::from_vec(vec![0.0, 1.0]);
849/// let mut values = Array::zeros(IxDyn(&[3, 2]));
850/// for i in 0..3 {
851///     for j in 0..2 {
852///         values[[i, j].as_slice()] = (i + j) as f64;
853///     }
854/// }
855///
856/// let interp = make_multilinear_interpolator(vec![x, y], values).expect("valid");
857/// ```
858pub fn make_multilinear_interpolator<
859    F: Float
860        + FromPrimitive
861        + Debug
862        + Display
863        + AddAssign
864        + SubAssign
865        + MulAssign
866        + DivAssign
867        + RemAssign
868        + scirs2_core::numeric::Zero
869        + 'static,
870>(
871    axes: Vec<Array1<F>>,
872    values: Array<F, IxDyn>,
873) -> InterpolateResult<TensorProductGridInterpolator<F>> {
874    TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
875}
876
877/// Create a tensor product B-spline interpolator
878///
879/// # Arguments
880///
881/// * `axes` - 1D coordinate arrays for each dimension
882/// * `values` - Values on the grid
883/// * `degree` - B-spline degree (1=linear, 3=cubic)
884pub fn make_tensor_bspline_interpolator<
885    F: Float
886        + FromPrimitive
887        + Debug
888        + Display
889        + AddAssign
890        + SubAssign
891        + MulAssign
892        + DivAssign
893        + RemAssign
894        + scirs2_core::numeric::Zero
895        + 'static,
896>(
897    axes: Vec<Array1<F>>,
898    values: Array<F, IxDyn>,
899    degree: usize,
900) -> InterpolateResult<TensorProductGridInterpolator<F>> {
901    TensorProductGridInterpolator::new(axes, values, TensorProductMethod::BSpline { degree })
902}
903
904// ---------------------------------------------------------------------------
905// Tests
906// ---------------------------------------------------------------------------
907
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array, Array1, IxDyn};

    /// f(x, y) = x + 2y sampled on a uniform 4x3 grid.
    fn make_2d_linear_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        // f(x, y) = x + 2y on a 4x3 grid
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let mut values = Array::zeros(IxDyn(&[4, 3]));
        for i in 0..4 {
            for j in 0..3 {
                values[[i, j].as_slice()] = x[i] + 2.0 * y[j];
            }
        }
        (vec![x, y], values)
    }

    /// f(x, y) = x * y sampled on a non-uniform 5x4 grid.
    fn make_2d_nonuniform_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        // Non-uniform spacing: f(x, y) = x * y
        let x = Array1::from_vec(vec![0.0, 0.5, 1.0, 2.0, 4.0]);
        let y = Array1::from_vec(vec![0.0, 0.1, 1.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[5, 4]));
        for i in 0..5 {
            for j in 0..4 {
                values[[i, j].as_slice()] = x[i] * y[j];
            }
        }
        (vec![x, y], values)
    }

    /// f(x, y, z) = x + y + z sampled on a uniform 3x3x3 grid.
    fn make_3d_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        // f(x, y, z) = x + y + z on a 3x3x3 grid
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let z = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let mut values = Array::zeros(IxDyn(&[3, 3, 3]));
        for i in 0..3 {
            for j in 0..3 {
                for k in 0..3 {
                    values[[i, j, k].as_slice()] = x[i] + y[j] + z[k];
                }
            }
        }
        (vec![x, y, z], values)
    }

    // === Multilinear interpolation tests ===

    #[test]
    fn test_multilinear_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::Multilinear,
        )
        .expect("valid");

        // Test at every grid point
        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-12,
                    "At grid point ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    #[test]
    fn test_multilinear_reproduces_linear_function() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        // Multilinear interpolation should reproduce linear functions exactly
        let test_points = vec![(0.5, 0.5), (1.5, 1.5), (2.5, 1.0), (0.3, 1.7)];
        for (x, y) in test_points {
            let result = interp.evaluate_point(&[x, y]).expect("valid");
            let expected = x + 2.0 * y;
            assert!(
                (result - expected).abs() < 1e-10,
                "Multilinear at ({}, {}): expected {}, got {}",
                x,
                y,
                expected,
                result
            );
        }
    }

    #[test]
    fn test_multilinear_nonuniform_grid() {
        let (axes, values) = make_2d_nonuniform_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        // Test at a known interior point
        // Between x=0.5 and x=1.0, y=0.1 and y=1.0
        let result = interp.evaluate_point(&[0.75, 0.55]).expect("valid");
        // Bilinear interpolation of x*y at (0.75, 0.55):
        // x fraction: (0.75 - 0.5) / (1.0 - 0.5) = 0.5
        // y fraction: (0.55 - 0.1) / (1.0 - 0.1) = 0.5
        // Corners: (0.5,0.1)=0.05, (0.5,1.0)=0.5, (1.0,0.1)=0.1, (1.0,1.0)=1.0
        // Result: 0.25*0.05 + 0.25*0.5 + 0.25*0.1 + 0.25*1.0 = 0.4125
        assert!(
            (result - 0.4125).abs() < 1e-10,
            "Nonuniform bilinear: expected 0.4125, got {}",
            result
        );
    }

    #[test]
    fn test_multilinear_3d() {
        let (axes, values) = make_3d_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        // Linear function should be reproduced exactly
        let result = interp.evaluate_point(&[0.5, 1.5, 0.5]).expect("valid");
        let expected = 0.5 + 1.5 + 0.5;
        assert!(
            (result - expected).abs() < 1e-10,
            "3D multilinear at (0.5, 1.5, 0.5): expected {}, got {}",
            expected,
            result
        );
    }

    // === Nearest interpolation tests ===

    #[test]
    fn test_nearest_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::Nearest,
        )
        .expect("valid");

        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-12,
                    "Nearest at grid point ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    #[test]
    fn test_nearest_between_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Nearest)
            .expect("valid");

        // (0.3, 0.3) is closest to grid point (0, 0) => value = 0+0 = 0
        let result = interp.evaluate_point(&[0.3, 0.3]).expect("valid");
        assert!(
            (result - 0.0).abs() < 1e-10,
            "Nearest at (0.3, 0.3): expected 0.0, got {}",
            result
        );

        // (2.7, 1.7) is closest to grid point (3, 2) => value = 3+4 = 7
        let result = interp.evaluate_point(&[2.7, 1.7]).expect("valid");
        assert!(
            (result - 7.0).abs() < 1e-10,
            "Nearest at (2.7, 1.7): expected 7.0, got {}",
            result
        );
    }

    // === B-spline interpolation tests ===

    #[test]
    fn test_bspline_linear_at_grid_points() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::BSpline { degree: 1 },
        )
        .expect("valid");

        // Degree-1 B-spline should reproduce grid values exactly
        for i in 0..4 {
            for j in 0..3 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-8,
                    "BSpline(1) at grid ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    /// f(x, y) = x + 2y on a 4x4 grid (enough points per axis for a cubic B-spline).
    fn make_2d_linear_grid_4x4() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[4, 4]));
        for i in 0..4 {
            for j in 0..4 {
                values[[i, j].as_slice()] = x[i] + 2.0 * y[j];
            }
        }
        (vec![x, y], values)
    }

    #[test]
    fn test_bspline_cubic_at_grid_points() {
        let (axes, values) = make_2d_linear_grid_4x4();
        let interp = TensorProductGridInterpolator::new(
            axes.clone(),
            values.clone(),
            TensorProductMethod::BSpline { degree: 3 },
        )
        .expect("valid");

        // Cubic B-spline should reproduce grid values exactly
        for i in 0..4 {
            for j in 0..4 {
                let result = interp
                    .evaluate_point(&[axes[0][i], axes[1][j]])
                    .expect("valid");
                let expected = values[[i, j].as_slice()];
                assert!(
                    (result - expected).abs() < 1e-6,
                    "BSpline(3) at grid ({}, {}): expected {}, got {}",
                    i,
                    j,
                    expected,
                    result
                );
            }
        }
    }

    #[test]
    fn test_bspline_cubic_interior_points() {
        let (axes, values) = make_2d_linear_grid_4x4();
        let interp = TensorProductGridInterpolator::new(
            axes,
            values,
            TensorProductMethod::BSpline { degree: 3 },
        )
        .expect("valid");

        // Linear functions lie in every spline space of degree >= 1. Since the cubic
        // interpolant matches the grid values (test_bspline_cubic_at_grid_points), it
        // must coincide with z = x + 2y everywhere in the domain, not just at the nodes.
        // A loose tolerance here would hide errors in the interior basis evaluation.
        for &(x, y) in &[(1.5, 0.5), (2.25, 1.75), (0.4, 2.6)] {
            let result = interp.evaluate_point(&[x, y]).expect("valid");
            let expected = x + 2.0 * y;
            assert!(
                (result - expected).abs() < 1e-6,
                "BSpline(3) at ({}, {}): expected {}, got {}",
                x,
                y,
                expected,
                result
            );
        }
    }

    // === Boundary handling tests ===

    #[test]
    fn test_boundary_clamp() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Clamp,
        )
        .expect("valid");

        // Point outside grid gets clamped
        let result = interp.evaluate_point(&[-1.0, -1.0]).expect("valid");
        // Clamped to (0, 0) => 0 + 0 = 0
        assert!(
            (result - 0.0).abs() < 1e-10,
            "Clamped at (-1,-1): expected 0.0, got {}",
            result
        );

        let result = interp.evaluate_point(&[10.0, 10.0]).expect("valid");
        // Clamped to (3, 2) => 3 + 4 = 7
        assert!(
            (result - 7.0).abs() < 1e-10,
            "Clamped at (10,10): expected 7.0, got {}",
            result
        );
    }

    #[test]
    fn test_boundary_error() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Error,
        )
        .expect("valid");

        let result = interp.evaluate_point(&[-1.0, 0.5]);
        assert!(result.is_err(), "Should error for out-of-bounds point");
    }

    #[test]
    fn test_boundary_nan() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Nan,
        )
        .expect("valid");

        let result = interp.evaluate_point(&[-1.0, 0.5]).expect("valid");
        assert!(result.is_nan(), "Should return NaN for out-of-bounds point");
    }

    #[test]
    fn test_boundary_extrapolate() {
        let (axes, values) = make_2d_linear_grid();
        let interp = TensorProductGridInterpolator::with_boundary(
            axes,
            values,
            TensorProductMethod::Multilinear,
            BoundaryHandling::Extrapolate,
        )
        .expect("valid");

        // For a linear function, extrapolation should give the correct value
        let result = interp.evaluate_point(&[-0.5, 0.5]).expect("valid");
        // z = x + 2y at (-0.5, 0.5) = -0.5 + 1.0 = 0.5
        assert!(
            (result - 0.5).abs() < 1e-10,
            "Extrapolated at (-0.5, 0.5): expected 0.5, got {}",
            result
        );
    }

    // === Batch evaluation tests ===

    #[test]
    fn test_batch_evaluation() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let points = vec![vec![0.5, 0.5], vec![1.5, 1.0], vec![2.0, 1.5]];
        let results = interp.evaluate_batch(&points).expect("valid");

        assert_eq!(results.len(), 3);
        assert!((results[0] - (0.5 + 1.0)).abs() < 1e-10);
        assert!((results[1] - (1.5 + 2.0)).abs() < 1e-10);
        assert!((results[2] - (2.0 + 3.0)).abs() < 1e-10);
    }

    // === Edge case tests ===

    #[test]
    fn test_empty_axes_rejected() {
        let axes: Vec<Array1<f64>> = vec![];
        let values = Array::zeros(IxDyn(&[]));
        let result =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear);
        assert!(result.is_err(), "Empty axes should be rejected");
    }

    #[test]
    fn test_too_few_points_rejected() {
        let x = Array1::from_vec(vec![0.0]); // Only 1 point
        let values = Array::zeros(IxDyn(&[1]));
        let result =
            TensorProductGridInterpolator::new(vec![x], values, TensorProductMethod::Multilinear);
        assert!(result.is_err(), "Single-point axis should be rejected");
    }

    #[test]
    fn test_nonsorted_axis_rejected() {
        let x = Array1::from_vec(vec![0.0, 2.0, 1.0]); // Not sorted
        let y = Array1::from_vec(vec![0.0, 1.0]);
        let values = Array::zeros(IxDyn(&[3, 2]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::Multilinear,
        );
        assert!(result.is_err(), "Non-sorted axis should be rejected");
    }

    #[test]
    fn test_shape_mismatch_rejected() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);
        let values = Array::zeros(IxDyn(&[3, 3])); // Wrong shape: should be (3, 2)
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::Multilinear,
        );
        assert!(result.is_err(), "Shape mismatch should be rejected");
    }

    #[test]
    fn test_wrong_dimension_query_rejected() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        let result = interp.evaluate_point(&[1.0]); // 1D query for 2D grid
        assert!(result.is_err(), "Wrong dimension query should be rejected");
    }

    // === Accessor tests ===

    #[test]
    fn test_accessors() {
        let (axes, values) = make_2d_linear_grid();
        let interp =
            TensorProductGridInterpolator::new(axes, values, TensorProductMethod::Multilinear)
                .expect("valid");

        assert_eq!(interp.ndim(), 2);
        assert_eq!(interp.shape(), &[4, 3]);
        assert_eq!(interp.axes().len(), 2);
    }

    // === Convenience constructor tests ===

    #[test]
    fn test_make_multilinear_interpolator() {
        let (axes, values) = make_2d_linear_grid();
        let interp = make_multilinear_interpolator(axes, values).expect("valid");
        let result = interp.evaluate_point(&[1.0, 1.0]).expect("valid");
        assert!((result - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_make_tensor_bspline_interpolator() {
        let (axes, values) = make_2d_linear_grid();
        let interp = make_tensor_bspline_interpolator(axes, values, 1).expect("valid");
        let result = interp.evaluate_point(&[1.0, 1.0]).expect("valid");
        assert!(
            (result - 3.0).abs() < 1e-6,
            "BSpline at (1,1): expected 3.0, got {}",
            result
        );
    }

    // === Convergence tests ===

    #[test]
    fn test_multilinear_convergence_quadratic() {
        // For f(x,y) = x^2 + y^2 (not linear), multilinear interpolation
        // should converge as the grid is refined.
        // Use an off-grid test point to avoid zero error from hitting a grid node
        let test_point = [0.37_f64, 0.63];
        let exact_value = 0.37 * 0.37 + 0.63 * 0.63;

        let mut errors = Vec::new();
        for &n in &[5, 10, 20, 40] {
            let x = Array1::linspace(0.0, 1.0, n);
            let y = Array1::linspace(0.0, 1.0, n);
            let mut values = Array::zeros(IxDyn(&[n, n]));
            for i in 0..n {
                for j in 0..n {
                    values[[i, j].as_slice()] = x[i] * x[i] + y[j] * y[j];
                }
            }

            let interp = TensorProductGridInterpolator::new(
                vec![x, y],
                values,
                TensorProductMethod::Multilinear,
            )
            .expect("valid");

            let result = interp.evaluate_point(&test_point).expect("valid");
            let error = (result - exact_value).abs();

            // Bilinear interpolation of g(x) + h(y) is the sum of the 1D linear
            // interpolants. For x^2 on a cell [a, b] the error is exactly
            // (x - a)(b - x) <= h^2 / 4, so the total error is at most h^2 / 2.
            // The error is not monotone in n (the point's position inside its
            // cell changes), so check against the O(h^2) bound instead.
            let h = 1.0 / (n as f64 - 1.0);
            let bound = 0.5 * h * h;
            assert!(
                error <= bound + 1e-12,
                "n = {}: error {} exceeds the O(h^2) bound {}",
                n,
                error,
                bound
            );
            errors.push(error);
        }

        // Overall error should decrease with refinement
        assert!(
            errors[errors.len() - 1] < errors[0],
            "Error should decrease: first={}, last={}",
            errors[0],
            errors[errors.len() - 1]
        );

        assert!(
            errors[errors.len() - 1] < 0.01,
            "Multilinear should converge to the exact value: final error = {}",
            errors[errors.len() - 1]
        );
    }

    // === 1D test ===

    #[test]
    fn test_1d_multilinear() {
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
        let mut values = Array::zeros(IxDyn(&[4]));
        for i in 0..4 {
            values[[i].as_slice()] = (i as f64) * (i as f64); // x^2
        }

        let interp =
            TensorProductGridInterpolator::new(vec![x], values, TensorProductMethod::Multilinear)
                .expect("valid");

        // At x=0.5: linear interp between 0 and 1 = 0.5
        let result = interp.evaluate_point(&[0.5]).expect("valid");
        assert!(
            (result - 0.5).abs() < 1e-10,
            "1D multilinear at 0.5: expected 0.5, got {}",
            result
        );
    }

    // === BSpline degree check ===

    #[test]
    fn test_bspline_insufficient_points_for_degree() {
        // Degree 3 needs at least 4 points per axis
        let x = Array1::from_vec(vec![0.0, 1.0, 2.0]); // Only 3 points
        let y = Array1::from_vec(vec![0.0, 1.0, 2.0]);
        let values = Array::zeros(IxDyn(&[3, 3]));
        let result = TensorProductGridInterpolator::new(
            vec![x, y],
            values,
            TensorProductMethod::BSpline { degree: 3 },
        );
        assert!(
            result.is_err(),
            "Should reject degree 3 with only 3 points per axis"
        );
    }
}