scirs2_interpolate/interpnd.rs
1//! N-dimensional interpolation methods
2//!
3//! This module provides functionality for interpolating multidimensional data.
4
5use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
6use crate::error::{InterpolateError, InterpolateResult};
7use scirs2_core::ndarray::{Array, Array1, Array2, ArrayView1, ArrayView2, IxDyn};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::{Debug, Display};
10use std::ops::{AddAssign, SubAssign};
11
/// Available grid types for N-dimensional interpolation
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the variants can be
/// used as keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GridType {
    /// Regular grid (evenly spaced points in each dimension)
    Regular,
    /// Rectilinear grid (unevenly spaced points along each axis)
    Rectilinear,
    /// Unstructured grid (arbitrary point positions)
    Unstructured,
}
22
/// Extrapolation mode for N-dimensional interpolation
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the mode can be used
/// as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtrapolateMode {
    /// Return NaN for points outside the interpolation domain
    Nan,
    /// Raise an error for points outside the interpolation domain
    Error,
    /// Extrapolate based on the nearest edge points
    Extrapolate,
}
33
34/// N-dimensional interpolation object for rectilinear grids
35///
36/// This interpolator works with data defined on a rectilinear grid,
37/// where each dimension has its own set of coordinates.
38#[derive(Debug, Clone)]
39pub struct RegularGridInterpolator<F: Float + FromPrimitive + Debug + Display> {
40 /// Grid points in each dimension
41 points: Vec<Array1<F>>,
42 /// Values at grid points
43 values: Array<F, IxDyn>,
44 /// Method to use for interpolation
45 method: InterpolationMethod,
46 /// How to handle points outside the domain
47 extrapolate: ExtrapolateMode,
48}
49
/// Available interpolation methods for N-dimensional interpolation
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the method can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterpolationMethod {
    /// Nearest neighbor interpolation
    Nearest,
    /// Linear interpolation
    Linear,
    /// Spline interpolation
    Spline,
}
60
61impl<F: crate::traits::InterpolationFloat> RegularGridInterpolator<F> {
62 /// Create a new RegularGridInterpolator
63 ///
64 /// # Arguments
65 ///
66 /// * `points` - A vector of arrays, where each array contains the points in one dimension
67 /// * `values` - An N-dimensional array of values at the grid points
68 /// * `method` - Interpolation method to use
69 /// * `extrapolate` - How to handle points outside the domain
70 ///
71 /// # Returns
72 ///
73 /// A new RegularGridInterpolator object
74 ///
75 /// # Errors
76 ///
77 /// * If points dimensions don't match values dimensions
78 /// * If any dimension has less than 2 points
79 ///
80 /// # Examples
81 ///
82 /// ```rust
83 /// use scirs2_core::ndarray::{Array, Array1, Dim, IxDyn};
84 /// use scirs2_interpolate::interpnd::{
85 /// RegularGridInterpolator, InterpolationMethod, ExtrapolateMode
86 /// };
87 ///
88 /// // Create a 3D grid
89 /// let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
90 /// let y = Array1::from_vec(vec![0.0, 1.0]);
91 /// let z = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
92 /// let points = vec![x, y, z];
93 ///
94 /// // Create values on the grid (3 × 2 × 4 = 24 values)
95 /// let mut values = Array::zeros(IxDyn(&[3, 2, 4]));
96 /// for i in 0..3 {
97 /// for j in 0..2 {
98 /// for k in 0..4 {
99 /// let idx = [i, j, k];
100 /// values[idx.as_slice()] = (i + j + k) as f64;
101 /// }
102 /// }
103 /// }
104 ///
105 /// let interpolator = RegularGridInterpolator::new(
106 /// points,
107 /// values,
108 /// InterpolationMethod::Linear,
109 /// ExtrapolateMode::Extrapolate,
110 /// ).unwrap();
111 /// ```
112 pub fn new(
113 points: Vec<Array1<F>>,
114 values: Array<F, IxDyn>,
115 method: InterpolationMethod,
116 extrapolate: ExtrapolateMode,
117 ) -> InterpolateResult<Self> {
118 // Check that points dimensions match values dimensions
119 if points.len() != values.ndim() {
120 return Err(InterpolateError::invalid_input(format!(
121 "Points dimensions ({}) do not match values dimensions ({})",
122 points.len(),
123 values.ndim()
124 )));
125 }
126
127 // Check that each dimension has at least 2 points
128 for (i, p) in points.iter().enumerate() {
129 if p.len() < 2 {
130 return Err(InterpolateError::invalid_input(format!(
131 "Dimension {} has less than 2 points",
132 i
133 )));
134 }
135
136 // Check that points are sorted
137 for j in 1..p.len() {
138 if p[j] <= p[j - 1] {
139 return Err(InterpolateError::invalid_input(format!(
140 "Points in dimension {} are not strictly increasing",
141 i
142 )));
143 }
144 }
145
146 // Check that values dimension matches points dimension
147 if p.len() != values.shape()[i] {
148 return Err(InterpolateError::invalid_input(format!(
149 "Values dimension {} size {} does not match points dimension size {}",
150 i,
151 values.shape()[i],
152 p.len()
153 )));
154 }
155 }
156
157 Ok(Self {
158 points,
159 values,
160 method,
161 extrapolate,
162 })
163 }
164
165 /// Interpolate at the given points
166 ///
167 /// # Arguments
168 ///
169 /// * `xi` - Array of points to interpolate at, shape (n_points, n_dims)
170 ///
171 /// # Returns
172 ///
173 /// Interpolated values at the given points, shape (n_points,)
174 ///
175 /// # Errors
176 ///
177 /// * If xi dimensions don't match grid dimensions
178 /// * If extrapolation is not allowed and points are outside the domain
179 ///
180 /// # Examples
181 ///
182 /// ```rust
183 /// use scirs2_core::ndarray::{Array, Array1, Array2, IxDyn};
184 /// use scirs2_interpolate::interpnd::{
185 /// RegularGridInterpolator, InterpolationMethod, ExtrapolateMode
186 /// };
187 ///
188 /// // Create a simple 2D grid
189 /// let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
190 /// let y = Array1::from_vec(vec![0.0, 1.0]);
191 /// let points = vec![x, y];
192 ///
193 /// let mut values = Array::zeros(IxDyn(&[3, 2]));
194 /// for i in 0..3 {
195 /// for j in 0..2 {
196 /// let idx = [i, j];
197 /// values[idx.as_slice()] = (i * i + j * j) as f64;
198 /// }
199 /// }
200 ///
201 /// let interpolator = RegularGridInterpolator::new(
202 /// points, values, InterpolationMethod::Linear, ExtrapolateMode::Extrapolate
203 /// ).unwrap();
204 ///
205 /// // Interpolate at multiple points
206 /// let xi = Array2::from_shape_vec((3, 2), vec![
207 /// 0.5, 0.5,
208 /// 1.0, 0.0,
209 /// 1.5, 0.5,
210 /// ]).unwrap();
211 ///
212 /// let results = interpolator.__call__(&xi.view()).unwrap();
213 /// assert_eq!(results.len(), 3);
214 /// ```
215 pub fn __call__(&self, xi: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
216 // Check that xi dimensions match grid dimensions
217 if xi.shape()[1] != self.points.len() {
218 return Err(InterpolateError::invalid_input(format!(
219 "Dimensions of interpolation points ({}) do not match grid dimensions ({})",
220 xi.shape()[1],
221 self.points.len()
222 )));
223 }
224
225 let n_points = xi.shape()[0];
226 let mut result = Array1::zeros(n_points);
227
228 for i in 0..n_points {
229 let point = xi.slice(scirs2_core::ndarray::s![i, ..]);
230 result[i] = self.interpolate_point(&point)?;
231 }
232
233 Ok(result)
234 }
235
236 /// Interpolate at a single point
237 ///
238 /// # Arguments
239 ///
240 /// * `point` - Coordinates of the point to interpolate at
241 ///
242 /// # Returns
243 ///
244 /// Interpolated value at the given point
245 fn interpolate_point(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
246 // Find the grid cells containing the point and calculate the normalized distances
247 let mut indices = Vec::with_capacity(self.points.len());
248 let mut weights = Vec::with_capacity(self.points.len());
249
250 for (dim, dim_points) in self.points.iter().enumerate() {
251 let x = point[dim];
252
253 // Check if point is outside the domain
254 if x < dim_points[0] || x > dim_points[dim_points.len() - 1] {
255 match self.extrapolate {
256 ExtrapolateMode::Error => {
257 return Err(InterpolateError::OutOfBounds(format!(
258 "Point outside domain in dimension {}: {} not in [{}, {}]",
259 dim,
260 x,
261 dim_points[0],
262 dim_points[dim_points.len() - 1]
263 )));
264 }
265 ExtrapolateMode::Nan => {
266 return Ok(F::nan());
267 }
268 // For extrapolate and constant modes, we'll find the nearest edge
269 _ => {}
270 }
271 }
272
273 // Find index of cell containing x
274 let idx = match self.method {
275 InterpolationMethod::Nearest => {
276 // For nearest, just find the closest point
277 let mut closest_idx = 0;
278 let mut min_dist = (x - dim_points[0]).abs();
279
280 for (j, &p) in dim_points.iter().enumerate().skip(1) {
281 let dist = (x - p).abs();
282 if dist < min_dist {
283 min_dist = dist;
284 closest_idx = j;
285 }
286 }
287
288 // Return just the index of the nearest point
289 indices.push(closest_idx);
290 weights.push(F::from_f64(1.0).unwrap());
291 continue;
292 }
293 _ => {
294 // For linear and spline, find the cell interval
295 let mut idx = dim_points.len() - 2;
296
297 // Find the cell that contains x (where x is between x[idx] and x[idx+1])
298 // Simply iterate through the points to find the right cell
299 let mut found = false;
300 for i in 0..dim_points.len() - 1 {
301 if x >= dim_points[i] && x <= dim_points[i + 1] {
302 idx = i;
303 found = true;
304 break;
305 }
306 }
307
308 // Handle extrapolation cases
309 if !found {
310 if x < dim_points[0] {
311 // Point is before the first grid point
312 if self.extrapolate == ExtrapolateMode::Extrapolate {
313 idx = 0;
314 } else if self.extrapolate == ExtrapolateMode::Error {
315 return Err(InterpolateError::out_of_domain(
316 x,
317 dim_points[0],
318 dim_points[dim_points.len() - 1],
319 "N-dimensional interpolation",
320 ));
321 } else {
322 // For Nan mode, clamp to boundary
323 idx = 0;
324 }
325 } else {
326 // Point is after the last grid point
327 idx = dim_points.len() - 2;
328 }
329 }
330
331 idx
332 }
333 };
334
335 // For linear interpolation, compute the weights
336 if self.method != InterpolationMethod::Nearest {
337 // Get the lower and upper bounds of the cell
338 let x0 = dim_points[idx];
339 let x1 = dim_points[idx + 1];
340
341 // Calculate the normalized distance for linear interpolation
342 // t is the fraction of the distance between x0 and x1
343 let t = if x1 == x0 {
344 F::from_f64(0.0).unwrap()
345 } else {
346 (x - x0) / (x1 - x0)
347 };
348
349 // Ensure t is between 0 and 1 (this handles any numerical precision issues)
350 let t = t
351 .max(F::from_f64(0.0).unwrap())
352 .min(F::from_f64(1.0).unwrap());
353
354 indices.push(idx);
355 weights.push(t);
356 }
357 }
358
359 // Perform the interpolation based on the method
360 match self.method {
361 InterpolationMethod::Nearest => {
362 // For nearest, we just return the value at the nearest grid point
363 let idx_array = indices.to_vec();
364 Ok(self.values[idx_array.as_slice()])
365 }
366 InterpolationMethod::Linear => {
367 // For linear, we need to compute a weighted average of the surrounding cell vertices
368 self.linear_interpolate(&indices, &weights)
369 }
370 InterpolationMethod::Spline => {
371 // For now, implement 2D spline interpolation only
372 if self.points.len() == 2 {
373 self.spline_interpolate_2d(point)
374 } else {
375 Err(InterpolateError::NotImplemented(format!(
376 "Spline interpolation only supports 2D grids, got {}D",
377 self.points.len()
378 )))
379 }
380 }
381 }
382 }
383
384 /// Perform linear interpolation
385 ///
386 /// # Arguments
387 ///
388 /// * `indices` - Indices of the cell containing the point
389 /// * `weights` - Normalized distances within the cell
390 ///
391 /// # Returns
392 ///
393 /// Interpolated value
394 fn linear_interpolate(&self, indices: &[usize], weights: &[F]) -> InterpolateResult<F> {
395 // For linear interpolation, we compute a weighted average of cell vertices
396 // Each vertex has a weight that is a product of 1D weights
397
398 // Handle the 2D case directly for better performance and correctness in test cases
399 if indices.len() == 2 {
400 // 2D case (rectangle)
401 let i0 = indices[0];
402 let i1 = indices[1];
403 let t0 = weights[0];
404 let t1 = weights[1];
405
406 // Get the values at the 4 corners
407 let idx00 = [i0, i1];
408 let idx01 = [i0, i1 + 1];
409 let idx10 = [i0 + 1, i1];
410 let idx11 = [i0 + 1, i1 + 1];
411
412 let v00 = self.values[idx00.as_slice()];
413 let v01 = self.values[idx01.as_slice()];
414 let v10 = self.values[idx10.as_slice()];
415 let v11 = self.values[idx11.as_slice()];
416
417 // Bilinear interpolation formula
418 // (1-t0)(1-t1)v00 + (1-t0)t1v01 + t0(1-t1)v10 + t0t1v11
419 let one = F::from_f64(1.0).unwrap();
420 let result = (one - t0) * (one - t1) * v00
421 + (one - t0) * t1 * v01
422 + t0 * (one - t1) * v10
423 + t0 * t1 * v11;
424
425 return Ok(result);
426 }
427
428 // General case for N dimensions
429 let n_dims = indices.len();
430 let mut result = F::from_f64(0.0).unwrap();
431
432 // We need to iterate through all 2^n_dims vertices of the hypercube
433 // Each vertex is identified by a binary pattern of lower/upper indices
434 let n_vertices = 1 << n_dims;
435
436 for vertex in 0..n_vertices {
437 // Build the index for this vertex and calculate its weight
438 let mut vertex_index = Vec::with_capacity(n_dims);
439 let mut vertex_weight = F::from_f64(1.0).unwrap();
440
441 for dim in 0..n_dims {
442 let use_upper = (vertex >> dim) & 1 == 1;
443 let idx = indices[dim] + if use_upper { 1 } else { 0 };
444 vertex_index.push(idx);
445
446 // Weight is either weight (for upper) or (1-weight) for lower
447 // For linear interpolation, weights represent normalized positions
448 // e.g., weight 0.7 means 70% toward upper point, 30% toward lower point
449 let dim_weight = if use_upper {
450 weights[dim]
451 } else {
452 F::from_f64(1.0).unwrap() - weights[dim]
453 };
454
455 vertex_weight *= dim_weight;
456 }
457
458 // Add the weighted value to the result
459 let vertex_value = self.values[vertex_index.as_slice()];
460 result += vertex_weight * vertex_value;
461 }
462
463 Ok(result)
464 }
465
466 /// Perform 2D spline interpolation
467 ///
468 /// # Arguments
469 ///
470 /// * `point` - Coordinates of the point to interpolate at
471 ///
472 /// # Returns
473 ///
474 /// Interpolated value at the point
475 fn spline_interpolate_2d(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
476 use crate::interp2d::{Interp2d, Interp2dKind};
477
478 if self.points.len() != 2 {
479 return Err(InterpolateError::invalid_input(
480 "spline_interpolate_2d requires exactly 2 dimensions",
481 ));
482 }
483
484 // Convert the N-D grid to 2D format
485 let x = &self.points[0];
486 let y = &self.points[1];
487
488 // The values should be in a 2D array format
489 let shape = self.values.shape();
490 if shape.len() != 2 {
491 return Err(InterpolateError::invalid_input(
492 "spline_interpolate_2d requires 2D value array",
493 ));
494 }
495
496 // Create 2D array from N-D array
497 let z = self
498 .values
499 .clone()
500 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
501 .map_err(|_| InterpolateError::invalid_input("Failed to convert to 2D array"))?;
502
503 // Create 2D interpolator
504 let interp = Interp2d::new(&x.view(), &y.view(), &z.view(), Interp2dKind::Cubic)?;
505
506 // Evaluate at the point
507 if point.len() != 2 {
508 return Err(InterpolateError::invalid_input(
509 "Point must have 2 coordinates for 2D spline interpolation",
510 ));
511 }
512
513 interp.evaluate(point[0], point[1])
514 }
515}
516
517/// N-dimensional interpolation on unstructured data (scattered points)
518///
519/// This interpolator works with data defined on scattered points without
520/// a regular grid structure, using various methods.
521#[derive(Debug, Clone)]
522#[allow(dead_code)]
523pub struct ScatteredInterpolator<F: Float + FromPrimitive + Debug + Display> {
524 /// Points coordinates, shape (n_points, n_dims)
525 points: Array2<F>,
526 /// Values at points, shape (n_points,)
527 values: Array1<F>,
528 /// Method to use for interpolation
529 method: ScatteredInterpolationMethod,
530 /// How to handle points outside the domain
531 extrapolate: ExtrapolateMode,
532 /// Additional parameters for specific methods
533 params: ScatteredInterpolatorParams<F>,
534}
535
536/// Parameters for scattered interpolation methods
537#[derive(Debug, Clone)]
538pub enum ScatteredInterpolatorParams<F: Float + FromPrimitive + Debug + Display> {
539 /// No additional parameters
540 None,
541 /// Parameters for IDW (Inverse Distance Weighting)
542 IDW {
543 /// Power parameter for IDW (default: 2.0)
544 power: F,
545 },
546 /// Parameters for RBF (Radial Basis Function)
547 RBF {
548 /// Epsilon parameter for RBF (default: 1.0)
549 epsilon: F,
550 /// Type of radial basis function
551 rbf_type: RBFType,
552 },
553}
554
/// Types of radial basis functions
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the kernel type can
/// be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RBFType {
    /// Gaussian: exp(-(εr)²)
    Gaussian,
    /// Multiquadric: sqrt(1 + (εr)²)
    Multiquadric,
    /// Inverse multiquadric: 1/sqrt(1 + (εr)²)
    InverseMultiquadric,
    /// Thin plate spline: (εr)² log(εr)
    ThinPlateSpline,
}
567
/// Available interpolation methods for scattered data
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` so the method can be
/// used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScatteredInterpolationMethod {
    /// Nearest neighbor interpolation
    Nearest,
    /// Inverse Distance Weighting
    IDW,
    /// Radial Basis Function interpolation
    RBF,
}
578
579impl<
580 F: Float
581 + FromPrimitive
582 + Debug
583 + Display
584 + AddAssign
585 + SubAssign
586 + std::fmt::LowerExp
587 + std::ops::MulAssign
588 + std::ops::DivAssign
589 + Send
590 + Sync
591 + 'static,
592 > ScatteredInterpolator<F>
593{
594 /// Create a new ScatteredInterpolator
595 ///
596 /// # Arguments
597 ///
598 /// * `points` - Coordinates of sample points, shape (n_points, n_dims)
599 /// * `values` - Values at sample points, shape (n_points,)
600 /// * `method` - Interpolation method to use
601 /// * `extrapolate` - How to handle points outside the domain
602 /// * `params` - Additional parameters for specific methods
603 ///
604 /// # Returns
605 ///
606 /// A new ScatteredInterpolator object
607 ///
608 /// # Errors
609 ///
610 /// * If points and values dimensions don't match
611 ///
612 /// # Examples
613 ///
614 /// ```rust
615 /// use scirs2_core::ndarray::{Array1, Array2};
616 /// use scirs2_interpolate::interpnd::{
617 /// ScatteredInterpolator, ScatteredInterpolationMethod,
618 /// ExtrapolateMode, ScatteredInterpolatorParams
619 /// };
620 ///
621 /// // Create scattered 3D data
622 /// let points = Array2::from_shape_vec((5, 3), vec![
623 /// 0.0, 0.0, 0.0,
624 /// 1.0, 0.0, 0.0,
625 /// 0.0, 1.0, 0.0,
626 /// 0.0, 0.0, 1.0,
627 /// 0.5, 0.5, 0.5,
628 /// ]).unwrap();
629 /// let values = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0, 1.5]);
630 ///
631 /// // Create IDW interpolator with custom power
632 /// let interpolator = ScatteredInterpolator::new(
633 /// points,
634 /// values,
635 /// ScatteredInterpolationMethod::IDW,
636 /// ExtrapolateMode::Extrapolate,
637 /// Some(ScatteredInterpolatorParams::IDW { power: 3.0 }),
638 /// ).unwrap();
639 /// ```
640 pub fn new(
641 points: Array2<F>,
642 values: Array1<F>,
643 method: ScatteredInterpolationMethod,
644 extrapolate: ExtrapolateMode,
645 params: Option<ScatteredInterpolatorParams<F>>,
646 ) -> InterpolateResult<Self> {
647 // Check that points and values have compatible dimensions
648 if points.shape()[0] != values.len() {
649 return Err(InterpolateError::invalid_input(format!(
650 "Number of points ({}) does not match number of values ({})",
651 points.shape()[0],
652 values.len()
653 )));
654 }
655
656 // Set default parameters based on method if not provided
657 let params = match params {
658 Some(p) => p,
659 None => match method {
660 ScatteredInterpolationMethod::Nearest => ScatteredInterpolatorParams::None,
661 ScatteredInterpolationMethod::IDW => ScatteredInterpolatorParams::IDW {
662 power: F::from_f64(2.0).unwrap(),
663 },
664 ScatteredInterpolationMethod::RBF => ScatteredInterpolatorParams::RBF {
665 epsilon: F::from_f64(1.0).unwrap(),
666 rbf_type: RBFType::Multiquadric,
667 },
668 },
669 };
670
671 Ok(Self {
672 points,
673 values,
674 method,
675 extrapolate,
676 params,
677 })
678 }
679
680 /// Interpolate at the given points
681 ///
682 /// # Arguments
683 ///
684 /// * `xi` - Array of points to interpolate at, shape (n_points, n_dims)
685 ///
686 /// # Returns
687 ///
688 /// Interpolated values at the given points, shape (n_points,)
689 ///
690 /// # Errors
691 ///
692 /// * If xi dimensions don't match input dimensions
693 pub fn __call__(&self, xi: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
694 // Check that xi dimensions match input dimensions
695 if xi.shape()[1] != self.points.shape()[1] {
696 return Err(InterpolateError::invalid_input(format!(
697 "Dimensions of interpolation points ({}) do not match input dimensions ({})",
698 xi.shape()[1],
699 self.points.shape()[1]
700 )));
701 }
702
703 let n_points = xi.shape()[0];
704 let mut result = Array1::zeros(n_points);
705
706 for i in 0..n_points {
707 let point = xi.slice(scirs2_core::ndarray::s![i, ..]);
708 result[i] = self.interpolate_point(&point)?;
709 }
710
711 Ok(result)
712 }
713
714 /// Interpolate at a single point
715 ///
716 /// # Arguments
717 ///
718 /// * `point` - Coordinates of the point to interpolate at
719 ///
720 /// # Returns
721 ///
722 /// Interpolated value at the given point
723 fn interpolate_point(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
724 match self.method {
725 ScatteredInterpolationMethod::Nearest => self.nearest_interpolate(point),
726 ScatteredInterpolationMethod::IDW => self.idw_interpolate(point),
727 ScatteredInterpolationMethod::RBF => self.rbf_interpolate(point),
728 }
729 }
730
731 /// Perform nearest neighbor interpolation
732 ///
733 /// # Arguments
734 ///
735 /// * `point` - Coordinates of the point to interpolate at
736 ///
737 /// # Returns
738 ///
739 /// Interpolated value at the given point
740 fn nearest_interpolate(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
741 let mut min_dist = F::infinity();
742 let mut nearest_idx = 0;
743
744 // Find the nearest point
745 for i in 0..self.points.shape()[0] {
746 let p = self.points.slice(scirs2_core::ndarray::s![i, ..]);
747 let dist = self.compute_distance(&p, point);
748
749 if dist < min_dist {
750 min_dist = dist;
751 nearest_idx = i;
752 }
753 }
754
755 Ok(self.values[nearest_idx])
756 }
757
758 /// Perform Inverse Distance Weighting interpolation
759 ///
760 /// # Arguments
761 ///
762 /// * `point` - Coordinates of the point to interpolate at
763 ///
764 /// # Returns
765 ///
766 /// Interpolated value at the given point
767 fn idw_interpolate(&self, point: &ArrayView1<F>) -> InterpolateResult<F> {
768 // Get the power parameter
769 let power = match self.params {
770 ScatteredInterpolatorParams::IDW { power } => power,
771 _ => F::from_f64(2.0).unwrap(), // Default to 2.0 if wrong params
772 };
773
774 let mut sum_weights = F::from_f64(0.0).unwrap();
775 let mut sum_weighted_values = F::from_f64(0.0).unwrap();
776
777 // Check for exact match with any input point
778 for i in 0..self.points.shape()[0] {
779 let p = self.points.slice(scirs2_core::ndarray::s![i, ..]);
780 let dist = self.compute_distance(&p, point);
781
782 if dist.is_zero() {
783 // Exact match found
784 return Ok(self.values[i]);
785 }
786
787 // Calculate weight as 1/distance^power
788 let weight = F::from_f64(1.0).unwrap() / dist.powf(power);
789 sum_weights += weight;
790 sum_weighted_values += weight * self.values[i];
791 }
792
793 // Calculate weighted average
794 if sum_weights.is_zero() {
795 // This should not happen with non-zero distances
796 return Err(InterpolateError::ComputationError(
797 "Sum of weights is zero in IDW interpolation".to_string(),
798 ));
799 }
800
801 Ok(sum_weighted_values / sum_weights)
802 }
803
804 /// Compute Euclidean distance between two points
805 ///
806 /// # Arguments
807 ///
808 /// * `p1` - First point
809 /// * `p2` - Second point
810 ///
811 /// # Returns
812 ///
813 /// Euclidean distance between the points
814 fn compute_distance(&self, p1: &ArrayView1<F>, p2: &ArrayView1<F>) -> F {
815 let mut sum_sq = F::from_f64(0.0).unwrap();
816 for i in 0..p1.len() {
817 let diff = p1[i] - p2[i];
818 sum_sq += diff * diff;
819 }
820 sum_sq.sqrt()
821 }
822
823 /// Perform RBF interpolation at a point
824 ///
825 /// # Arguments
826 ///
827 /// * `point` - Coordinates of the point to interpolate at
828 ///
829 /// # Returns
830 ///
831 /// Interpolated value at the point
832 fn rbf_interpolate(&self, point: &ArrayView1<F>) -> InterpolateResult<F>
833 where
834 F: Float
835 + FromPrimitive
836 + Debug
837 + Display
838 + AddAssign
839 + std::ops::SubAssign
840 + std::fmt::LowerExp
841 + std::ops::MulAssign
842 + std::ops::DivAssign
843 + Send
844 + Sync
845 + 'static,
846 {
847 // Create RBF interpolator
848 let epsilon = F::from_f64(1.0).unwrap(); // Default shape parameter
849 let rbf = RBFInterpolator::new(
850 &self.points.view(),
851 &self.values.view(),
852 RBFKernel::Gaussian,
853 epsilon,
854 )?;
855
856 // Evaluate at the query point (reshape 1D point to 2D for RBF interface)
857 let binding = point.to_owned();
858 let point_2d = binding.to_shape((1, point.len())).unwrap();
859 let result = rbf.evaluate(&point_2d.view())?;
860 Ok(result[0])
861 }
862}
863
864/// Create an N-dimensional interpolator on a regular grid
865///
866/// # Arguments
867///
868/// * `points` - A vector of arrays, where each array contains the points in one dimension
869/// * `values` - An N-dimensional array of values at the grid points
870/// * `method` - Interpolation method to use
871/// * `extrapolate` - How to handle points outside the domain
872///
873/// # Returns
874///
875/// A new RegularGridInterpolator object
876///
877/// # Errors
878///
879/// * If points dimensions don't match values dimensions
880/// * If any dimension has less than 2 points
881///
882/// # Examples
883///
884/// ```
885/// use scirs2_core::ndarray::{Array, Array1, Dim, IxDyn};
886/// use scirs2_core::numeric::Float;
887/// use scirs2_interpolate::interpnd::{
888/// make_interp_nd, InterpolationMethod, ExtrapolateMode
889/// };
890///
891/// // Create a 2D grid
892/// let x = Array1::from_vec(vec![0.0, 1.0, 2.0]);
893/// let y = Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]);
894/// let points = vec![x, y];
895///
896/// // Create values on the grid (z = x^2 + y^2)
897/// let mut values = Array::zeros(IxDyn(&[3, 4]));
898/// for i in 0..3 {
899/// for j in 0..4 {
900/// let idx = [i, j];
901/// values[idx.as_slice()] = (i * i + j * j) as f64;
902/// }
903/// }
904///
905/// // Create the interpolator
906/// let interp = make_interp_nd(
907/// points,
908/// values,
909/// InterpolationMethod::Linear,
910/// ExtrapolateMode::Extrapolate,
911/// ).unwrap();
912///
913/// // Interpolate at a point
914/// use scirs2_core::ndarray::Array2;
915/// let points_to_interp = Array2::from_shape_vec((1, 2), vec![1.5, 2.5]).unwrap();
916/// let result = interp.__call__(&points_to_interp.view()).unwrap();
917/// assert!((result[0] - 9.0).abs() < 1e-10);
918/// ```
919#[allow(dead_code)]
920pub fn make_interp_nd<F: crate::traits::InterpolationFloat>(
921 points: Vec<Array1<F>>,
922 values: Array<F, IxDyn>,
923 method: InterpolationMethod,
924 extrapolate: ExtrapolateMode,
925) -> InterpolateResult<RegularGridInterpolator<F>> {
926 RegularGridInterpolator::new(points, values, method, extrapolate)
927}
928
929/// Create an N-dimensional interpolator for scattered data
930///
931/// # Arguments
932///
933/// * `points` - Coordinates of sample points, shape (n_points, n_dims)
934/// * `values` - Values at sample points, shape (n_points,)
935/// * `method` - Interpolation method to use
936/// * `extrapolate` - How to handle points outside the domain
937/// * `params` - Additional parameters for specific methods
938///
939/// # Returns
940///
941/// A new ScatteredInterpolator object
942///
943/// # Errors
944///
945/// * If points and values dimensions don't match
946#[allow(dead_code)]
947pub fn make_interp_scattered<F: crate::traits::InterpolationFloat>(
948 points: Array2<F>,
949 values: Array1<F>,
950 method: ScatteredInterpolationMethod,
951 extrapolate: ExtrapolateMode,
952 params: Option<ScatteredInterpolatorParams<F>>,
953) -> InterpolateResult<ScatteredInterpolator<F>> {
954 ScatteredInterpolator::new(points, values, method, extrapolate, params)
955}
956
957/// Map values on a rectilinear grid to a new grid
958///
959/// # Arguments
960///
961/// * `old_grid` - Vec of Arrays representing the old grid points in each dimension
962/// * `old_values` - Values at old grid points
963/// * `new_grid` - Vec of Arrays representing the new grid points in each dimension
964/// * `method` - Interpolation method to use
965///
966/// # Returns
967///
968/// Values at new grid points
969///
970/// # Errors
971///
972/// * If dimensions don't match
973/// * If any dimension has less than 2 points
974#[allow(dead_code)]
975pub fn map_coordinates<F: crate::traits::InterpolationFloat>(
976 old_grid: Vec<Array1<F>>,
977 old_values: Array<F, IxDyn>,
978 new_grid: Vec<Array1<F>>,
979 method: InterpolationMethod,
980) -> InterpolateResult<Array<F, IxDyn>> {
981 // Create the interpolator
982 let interp =
983 RegularGridInterpolator::new(old_grid, old_values, method, ExtrapolateMode::Error)?;
984
985 // Determine the shape of the output array
986 let outshape: Vec<usize> = new_grid.iter().map(|x| x.len()).collect();
987 let n_dims = outshape.len();
988
989 // Create meshgrid of coordinates
990 let mut indices = vec![Vec::<F>::new(); n_dims];
991 let mut shape = vec![1; n_dims];
992
993 for (i, grid) in new_grid.iter().enumerate() {
994 let mut idx = vec![F::from_f64(0.0).unwrap(); grid.len()];
995 for (j, val) in grid.iter().enumerate() {
996 idx[j] = *val;
997 }
998 indices[i] = idx;
999 shape[i] = grid.len();
1000 }
1001
1002 // Calculate total number of points
1003 let total_points: usize = shape.iter().product();
1004
1005 // Create the output array
1006 let mut out_values = Array::zeros(IxDyn(&outshape));
1007
1008 // Create a 2D array of all points to interpolate
1009 let mut points = Array2::zeros((total_points, n_dims));
1010
1011 // Create a multi-index for traversing the _grid
1012 let mut multi_index = vec![0; n_dims];
1013
1014 for flat_idx in 0..total_points {
1015 // Convert flat index to multi-index
1016 let mut temp = flat_idx;
1017 for i in (0..n_dims).rev() {
1018 multi_index[i] = temp % shape[i];
1019 temp /= shape[i];
1020 }
1021
1022 // Set point coordinates
1023 for i in 0..n_dims {
1024 points[[flat_idx, i]] = indices[i][multi_index[i]];
1025 }
1026 }
1027
1028 // Perform interpolation for all points
1029 let values = interp.__call__(&points.view())?;
1030
1031 // Reshape the result to match the output _grid
1032 let mut out_idx_vec = Vec::with_capacity(n_dims);
1033 for flat_idx in 0..total_points {
1034 // Convert flat index to multi-index
1035 let mut temp = flat_idx;
1036 for i in (0..n_dims).rev() {
1037 multi_index[i] = temp % shape[i];
1038 temp /= shape[i];
1039 }
1040
1041 // Convert multi-index to output index vector
1042 out_idx_vec.clear();
1043 out_idx_vec.extend_from_slice(&multi_index[..n_dims]);
1044
1045 // Set the value in the output array
1046 *out_values.get_mut(out_idx_vec.as_slice()).unwrap() = values[flat_idx];
1047 }
1048
1049 Ok(out_values)
1050}
1051
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{Array2, IxDyn}; // for array operations

    /// The 3 x 4 grid used by the regular-grid tests, sampled from z = x^2 + y^2.
    fn paraboloid_grid() -> (Vec<Array1<f64>>, Array<f64, IxDyn>) {
        let axes = vec![
            Array1::from_vec(vec![0.0, 1.0, 2.0]),
            Array1::from_vec(vec![0.0, 1.0, 2.0, 3.0]),
        ];

        let mut grid_values = Array::zeros(IxDyn(&[3, 4]));
        for i in 0..3 {
            for j in 0..4 {
                let node = [i, j];
                grid_values[node.as_slice()] = (i * i + j * j) as f64;
            }
        }

        (axes, grid_values)
    }

    #[test]
    fn test_regular_grid_interpolator_2d() {
        let (axes, grid_values) = paraboloid_grid();

        let linear = RegularGridInterpolator::new(
            axes.clone(),
            grid_values.clone(),
            InterpolationMethod::Linear,
            ExtrapolateMode::Extrapolate,
        )
        .unwrap();

        // A query on a node reproduces the stored value: 1^2 + 2^2 = 5
        let on_node = Array2::from_shape_vec((1, 2), vec![1.0, 2.0]).unwrap();
        let on_node_out = linear.__call__(&on_node.view()).unwrap();
        assert_abs_diff_eq!(on_node_out[0], 5.0, epsilon = 1e-10);

        // The centre of cell [1,2] x [2,3] averages its corners:
        // (5 + 10 + 8 + 13) / 4 = 9
        let cell_centre = Array2::from_shape_vec((1, 2), vec![1.5, 2.5]).unwrap();
        let cell_centre_out = linear.__call__(&cell_centre.view()).unwrap();
        assert_abs_diff_eq!(cell_centre_out[0], 9.0, epsilon = 1e-10);

        // Several queries in one call
        let batch = Array2::from_shape_vec((2, 2), vec![1.0, 1.0, 2.0, 2.0]).unwrap();
        let batch_out = linear.__call__(&batch.view()).unwrap();
        assert_abs_diff_eq!(batch_out[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(batch_out[1], 8.0, epsilon = 1e-10);

        // Nearest neighbour: (1.6, 1.7) snaps to node (2, 2), whose value is 8
        let nearest = RegularGridInterpolator::new(
            axes,
            grid_values,
            InterpolationMethod::Nearest,
            ExtrapolateMode::Extrapolate,
        )
        .unwrap();
        let off_node = Array2::from_shape_vec((1, 2), vec![1.6, 1.7]).unwrap();
        let off_node_out = nearest.__call__(&off_node.view()).unwrap();
        assert_abs_diff_eq!(off_node_out[0], 8.0, epsilon = 1e-10);
    }

    #[test]
    fn test_scattered_interpolator() {
        // Unit-square corners plus its centre, sampled from z = x^2 + y^2
        let sample_points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .unwrap();
        let sample_values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 0.5]);

        let idw = ScatteredInterpolator::new(
            sample_points.clone(),
            sample_values.clone(),
            ScatteredInterpolationMethod::IDW,
            ExtrapolateMode::Extrapolate,
            Some(ScatteredInterpolatorParams::IDW { power: 2.0 }),
        )
        .unwrap();

        // Midpoint of the bottom edge: the weighted blend stays strictly inside (0, 1)
        let edge_mid = Array2::from_shape_vec((1, 2), vec![0.5, 0.0]).unwrap();
        let blended = idw.__call__(&edge_mid.view()).unwrap();
        assert!(blended[0] > 0.0 && blended[0] < 1.0);

        let nearest = ScatteredInterpolator::new(
            sample_points,
            sample_values,
            ScatteredInterpolationMethod::Nearest,
            ExtrapolateMode::Extrapolate,
            None,
        )
        .unwrap();

        // (0.6, 0.6) is closest to the centre sample (0.5, 0.5)
        let near_centre = Array2::from_shape_vec((1, 2), vec![0.6, 0.6]).unwrap();
        let snapped = nearest.__call__(&near_centre.view()).unwrap();
        assert_abs_diff_eq!(snapped[0], 0.5, epsilon = 1e-10);
    }
}