Skip to main content

scirs2_spatial/interpolate/
idw.rs

1//! Inverse Distance Weighting interpolation
2//!
3//! This module provides Inverse Distance Weighting (IDW) interpolation, a
4//! simple and efficient method for interpolating scattered data.
5//!
6//! IDW interpolation works by weighting neighboring points by the inverse of
7//! their distance raised to a power. The power parameter controls the smoothness
8//! of the interpolation, with higher values giving more weight to close points.
9//!
10//! The method is fast but can produce "bull's-eye" patterns around sample points,
11//! especially with high power values.
12
13use crate::distance::EuclideanDistance;
14use crate::error::{SpatialError, SpatialResult};
15use crate::kdtree::KDTree;
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
17
/// Inverse Distance Weighting interpolator for scattered data
///
/// An estimate at a query point is a weighted mean of the sample values, where
/// each sample is weighted by the inverse of its distance to the query raised to
/// `power`. Either every sample contributes (`n_neighbors == None`) or only the
/// nearest `k` samples found through the internal KD-tree (`n_neighbors == Some(k)`).
///
/// # Examples
///
/// ```
/// use scirs2_spatial::interpolate::IDWInterpolator;
/// use scirs2_core::ndarray::array;
///
/// // Create sample points and values
/// let points = array![
///     [0.0, 0.0],
///     [1.0, 0.0],
///     [0.0, 1.0],
///     [1.0, 1.0],
/// ];
/// let values = array![0.0, 1.0, 2.0, 3.0];
///
/// // Create interpolator with power=2
/// let interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, None).expect("Operation failed");
///
/// // Interpolate at a point
/// let query_point = array![0.5, 0.5];
/// let result = interp.interpolate(&query_point.view()).expect("Operation failed");
///
/// // Should be close to 1.5
/// assert!((result - 1.5).abs() < 0.1);
/// ```
#[derive(Debug, Clone)]
pub struct IDWInterpolator {
    /// Input points (N x D); row `i` is the location of `values[i]`
    points: Array2<f64>,
    /// Input values (N), one per row of `points`
    values: Array1<f64>,
    /// Dimensionality of the input points (number of columns of `points`)
    dim: usize,
    /// Number of input points (number of rows of `points`)
    n_points: usize,
    /// Power parameter (p); larger values favour closer samples more strongly
    power: f64,
    /// Number of neighbors to use (None means use all points)
    n_neighbors: Option<usize>,
    /// KD-tree built over `points`, used for k-nearest-neighbor lookups
    kdtree: KDTree<f64, EuclideanDistance<f64>>,
}
62
63impl IDWInterpolator {
64    /// Create a new IDW interpolator
65    ///
66    /// # Arguments
67    ///
68    /// * `points` - Input points with shape (n_samples, n_dims)
69    /// * `values` - Input values with shape (n_samples,)
70    /// * `power` - Power parameter (p), controls the importance of nearby points
71    /// * `n_neighbors` - Number of neighbors to use (None = use all points)
72    ///
73    /// # Returns
74    ///
75    /// A new IDWInterpolator
76    ///
77    /// # Errors
78    ///
79    /// * If points and values have different lengths
80    /// * If power is negative
81    /// * If n_neighbors is 0 or greater than n_points
82    pub fn new(
83        points: &ArrayView2<'_, f64>,
84        values: &ArrayView1<f64>,
85        power: f64,
86        n_neighbors: Option<usize>,
87    ) -> SpatialResult<Self> {
88        // Check input dimensions
89        let n_points = points.nrows();
90        let dim = points.ncols();
91
92        if n_points != values.len() {
93            return Err(SpatialError::DimensionError(format!(
94                "Number of points ({}) must match number of values ({})",
95                n_points,
96                values.len()
97            )));
98        }
99
100        if power < 0.0 {
101            return Err(SpatialError::ValueError(format!(
102                "Power parameter must be non-negative, got {power}"
103            )));
104        }
105
106        if let Some(k) = n_neighbors {
107            if k == 0 {
108                return Err(SpatialError::ValueError(
109                    "Number of _neighbors must be positive".to_string(),
110                ));
111            }
112            if k > n_points {
113                return Err(SpatialError::ValueError(format!(
114                    "Number of _neighbors ({k}) cannot exceed number of points ({n_points})"
115                )));
116            }
117        }
118
119        // Build KD-tree for fast nearest neighbor lookups
120        let kdtree = KDTree::new(&points.to_owned())?;
121
122        Ok(Self {
123            points: points.to_owned(),
124            values: values.to_owned(),
125            dim,
126            n_points,
127            power,
128            n_neighbors,
129            kdtree,
130        })
131    }
132
133    /// Interpolate at a single point
134    ///
135    /// # Arguments
136    ///
137    /// * `point` - Query point with shape (n_dims,)
138    ///
139    /// # Returns
140    ///
141    /// Interpolated value at the query point
142    ///
143    /// # Errors
144    ///
145    /// * If the point dimensions don't match the interpolator
146    pub fn interpolate(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
147        // Check dimension
148        if point.len() != self.dim {
149            return Err(SpatialError::DimensionError(format!(
150                "Query point has dimension {}, expected {}",
151                point.len(),
152                self.dim
153            )));
154        }
155
156        // Handle exact matches first
157        for i in 0..self.n_points {
158            let data_point = self.points.row(i);
159            if Self::is_same_point(&data_point, point) {
160                return Ok(self.values[i]);
161            }
162        }
163
164        // Get the neighbors to use
165        let (indices, distances) = match self.n_neighbors {
166            Some(k) => {
167                // Use k nearest neighbors
168                self.kdtree
169                    .query(point.as_slice().expect("Operation failed"), k)?
170            }
171            None => {
172                // Use all points
173                let mut indices = Vec::with_capacity(self.n_points);
174                let mut distances = Vec::with_capacity(self.n_points);
175
176                for i in 0..self.n_points {
177                    let data_point = self.points.row(i);
178                    let dist_sq = Self::squared_distance(&data_point, point);
179                    indices.push(i);
180                    distances.push(dist_sq);
181                }
182
183                (indices, distances)
184            }
185        };
186
187        // Calculate IDW weights and interpolated value
188        let mut weighted_sum = 0.0;
189        let mut weight_sum = 0.0;
190
191        for i in 0..indices.len() {
192            let dist_sq = distances[i];
193
194            // Handle zero distance case (coincident point)
195            if dist_sq < 1e-10 {
196                return Ok(self.values[indices[i]]);
197            }
198
199            // Calculate weight
200            let weight = 1.0 / dist_sq.powf(self.power / 2.0);
201
202            weighted_sum += weight * self.values[indices[i]];
203            weight_sum += weight;
204        }
205
206        if weight_sum > 0.0 {
207            Ok(weighted_sum / weight_sum)
208        } else {
209            // This should not happen with valid data
210            Err(SpatialError::ComputationError(
211                "Zero weight sum in IDW interpolation".to_string(),
212            ))
213        }
214    }
215
216    /// Interpolate at multiple points
217    ///
218    /// # Arguments
219    ///
220    /// * `points` - Query points with shape (n_queries, n_dims)
221    ///
222    /// # Returns
223    ///
224    /// Interpolated values with shape (n_queries,)
225    ///
226    /// # Errors
227    ///
228    /// * If the points dimensions don't match the interpolator
229    pub fn interpolate_many(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
230        // Check dimensions
231        if points.ncols() != self.dim {
232            return Err(SpatialError::DimensionError(format!(
233                "Query _points have dimension {}, expected {}",
234                points.ncols(),
235                self.dim
236            )));
237        }
238
239        let n_queries = points.nrows();
240        let mut results = Array1::zeros(n_queries);
241
242        // Interpolate each point
243        for i in 0..n_queries {
244            let point = points.row(i);
245            results[i] = self.interpolate(&point)?;
246        }
247
248        Ok(results)
249    }
250
251    /// Get the power parameter used by this interpolator
252    pub fn power(&self) -> f64 {
253        self.power
254    }
255
256    /// Get the number of neighbors used by this interpolator
257    pub fn n_neighbors(&self) -> Option<usize> {
258        self.n_neighbors
259    }
260
261    /// Set the power parameter
262    ///
263    /// # Arguments
264    ///
265    /// * `power` - New power parameter
266    ///
267    /// # Errors
268    ///
269    /// * If power is negative
270    pub fn set_power(&mut self, power: f64) -> SpatialResult<()> {
271        if power < 0.0 {
272            return Err(SpatialError::ValueError(format!(
273                "Power parameter must be non-negative, got {power}"
274            )));
275        }
276
277        self.power = power;
278        Ok(())
279    }
280
281    /// Set the number of neighbors
282    ///
283    /// # Arguments
284    ///
285    /// * `n_neighbors` - New number of neighbors (None = use all points)
286    ///
287    /// # Errors
288    ///
289    /// * If n_neighbors is 0 or greater than n_points
290    pub fn set_n_neighbors(&mut self, _nneighbors: Option<usize>) -> SpatialResult<()> {
291        if let Some(k) = _nneighbors {
292            if k == 0 {
293                return Err(SpatialError::ValueError(
294                    "Number of _neighbors must be positive".to_string(),
295                ));
296            }
297            if k > self.n_points {
298                return Err(SpatialError::ValueError(format!(
299                    "Number of _neighbors ({}) cannot exceed number of points ({})",
300                    k, self.n_points
301                )));
302            }
303        }
304
305        self.n_neighbors = _nneighbors;
306        Ok(())
307    }
308
309    /// Check if two points are the same (within a small tolerance)
310    ///
311    /// # Arguments
312    ///
313    /// * `p1` - First point
314    /// * `p2` - Second point
315    ///
316    /// # Returns
317    ///
318    /// True if the points are considered the same
319    fn is_same_point(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> bool {
320        Self::squared_distance(p1, p2) < 1e-10
321    }
322
323    /// Compute the squared Euclidean distance between two points
324    ///
325    /// # Arguments
326    ///
327    /// * `p1` - First point
328    /// * `p2` - Second point
329    ///
330    /// # Returns
331    ///
332    /// Squared Euclidean distance between the points
333    fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
334        let mut sum_sq = 0.0;
335        for i in 0..p1.len().min(p2.len()) {
336            let diff = p1[i] - p2[i];
337            sum_sq += diff * diff;
338        }
339        sum_sq
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Corners of the unit square sampled from `z = x + y`.
    fn unit_square() -> (Array2<f64>, Array1<f64>) {
        (
            array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
            array![0.0, 1.0, 1.0, 2.0],
        )
    }

    #[test]
    fn test_idw_interpolation_basic() {
        let (points, values) = unit_square();

        for &power in &[1.0, 2.0, 3.0] {
            let interp = IDWInterpolator::new(&points.view(), &values.view(), power, None)
                .expect("valid interpolator");

            // Querying a sample location reproduces its value exactly.
            for (row, &expected) in points.outer_iter().zip(values.iter()) {
                let val = interp.interpolate(&row).expect("valid query");
                assert_relative_eq!(val, expected, epsilon = 1e-10);
            }

            // The centre is equidistant from every corner.
            let val_center = interp
                .interpolate(&array![0.5, 0.5].view())
                .expect("valid query");
            assert_relative_eq!(val_center, 1.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_idw_with_neighbors() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.2, 0.8],
            [0.8, 0.2],
            [0.3, 0.3],
            [0.7, 0.7],
        ];

        // Function z = x + y
        let values = Array1::from_vec(
            points
                .rows()
                .into_iter()
                .map(|row| row[0] + row[1])
                .collect(),
        );

        let interp_all = IDWInterpolator::new(&points.view(), &values.view(), 2.0, None)
            .expect("valid interpolator");
        let interp_3 = IDWInterpolator::new(&points.view(), &values.view(), 2.0, Some(3))
            .expect("valid interpolator");

        let test_point = array![0.6, 0.4];
        let val_all = interp_all
            .interpolate(&test_point.view())
            .expect("valid query");
        let val_3 = interp_3
            .interpolate(&test_point.view())
            .expect("valid query");

        // Both should be close to x + y = 0.6 + 0.4 = 1.0
        assert_relative_eq!(val_all, 1.0, epsilon = 0.1);
        assert_relative_eq!(val_3, 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_neighbors_reproduce_samples() {
        let (points, values) = unit_square();
        let interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, Some(2))
            .expect("valid interpolator");

        for (row, &expected) in points.outer_iter().zip(values.iter()) {
            let val = interp.interpolate(&row).expect("valid query");
            assert_relative_eq!(val, expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_single_neighbor_is_nearest_value() {
        let (points, values) = unit_square();
        let interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, Some(1))
            .expect("valid interpolator");

        // (0.9, 0.2) is closest to the corner (1, 0), whose value is 1.0.
        let val = interp
            .interpolate(&array![0.9, 0.2].view())
            .expect("valid query");
        assert_relative_eq!(val, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_zero_power_is_mean() {
        let (points, values) = unit_square();
        let interp = IDWInterpolator::new(&points.view(), &values.view(), 0.0, None)
            .expect("valid interpolator");

        // With p = 0 every sample has weight 1, so the estimate is the plain mean.
        let val = interp
            .interpolate(&array![0.3, 0.9].view())
            .expect("valid query");
        assert_relative_eq!(val, 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_interpolate_many() {
        let (points, values) = unit_square();
        let interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, None)
            .expect("valid interpolator");

        let query_points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
        let results = interp
            .interpolate_many(&query_points.view())
            .expect("valid queries");

        assert_eq!(results.len(), 5);
        assert_relative_eq!(results[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(results[1], 1.0, epsilon = 1e-10);
        assert_relative_eq!(results[2], 1.0, epsilon = 1e-10);
        assert_relative_eq!(results[3], 2.0, epsilon = 1e-10);
        assert_relative_eq!(results[4], 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_setter_methods() {
        let (points, values) = unit_square();
        let mut interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, None)
            .expect("valid interpolator");

        assert_eq!(interp.power(), 2.0);
        assert_eq!(interp.n_neighbors(), None);

        interp.set_power(3.0).expect("valid power");
        assert_eq!(interp.power(), 3.0);

        interp.set_n_neighbors(Some(2)).expect("valid neighbor count");
        assert_eq!(interp.n_neighbors(), Some(2));

        interp.set_n_neighbors(None).expect("None is always valid");
        assert_eq!(interp.n_neighbors(), None);

        // Invalid updates are rejected.
        assert!(interp.set_power(-1.0).is_err());
        assert!(interp.set_n_neighbors(Some(0)).is_err());
        assert!(interp.set_n_neighbors(Some(10)).is_err());
    }

    #[test]
    fn test_error_handling() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![0.0, 1.0, 1.0];

        let interp = IDWInterpolator::new(&points.view(), &values.view(), 2.0, None)
            .expect("valid interpolator");

        // Wrong query dimensions
        assert!(interp.interpolate(&array![0.0].view()).is_err());
        assert!(interp
            .interpolate_many(&array![[0.0, 0.0, 0.0]].view())
            .is_err());

        // Negative power
        assert!(IDWInterpolator::new(&points.view(), &values.view(), -1.0, None).is_err());

        // Invalid neighbor counts
        assert!(IDWInterpolator::new(&points.view(), &values.view(), 2.0, Some(0)).is_err());
        assert!(IDWInterpolator::new(&points.view(), &values.view(), 2.0, Some(10)).is_err());
    }
}