1use crate::distance::EuclideanDistance;
14use crate::error::{SpatialError, SpatialResult};
15use crate::kdtree::KDTree;
16use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
17
/// Inverse distance weighting (IDW) interpolator for scattered data.
///
/// Holds an owned copy of the data points and their values, together with a
/// KD-tree built over the points that is used when an estimate is restricted
/// to the `k` nearest data points.
#[derive(Debug, Clone)]
pub struct IDWInterpolator {
    /// Data point coordinates, one point per row.
    points: Array2<f64>,
    /// Value associated with each data point (`values[i]` belongs to row `i` of `points`).
    values: Array1<f64>,
    /// Dimension of the data points (number of columns of `points`).
    dim: usize,
    /// Number of data points (number of rows of `points`).
    n_points: usize,
    /// Distance exponent used in the inverse-distance weights.
    power: f64,
    /// If `Some(k)`, only the `k` nearest data points contribute to an estimate;
    /// if `None`, every data point contributes.
    n_neighbors: Option<usize>,
    /// Spatial index over `points`, queried for nearest neighbours when
    /// `n_neighbors` is `Some(k)`.
    kdtree: KDTree<f64, EuclideanDistance<f64>>,
}
62
63impl IDWInterpolator {
64 pub fn new(
83 points: &ArrayView2<'_, f64>,
84 values: &ArrayView1<f64>,
85 power: f64,
86 n_neighbors: Option<usize>,
87 ) -> SpatialResult<Self> {
88 let n_points = points.nrows();
90 let dim = points.ncols();
91
92 if n_points != values.len() {
93 return Err(SpatialError::DimensionError(format!(
94 "Number of points ({}) must match number of values ({})",
95 n_points,
96 values.len()
97 )));
98 }
99
100 if power < 0.0 {
101 return Err(SpatialError::ValueError(format!(
102 "Power parameter must be non-negative, got {power}"
103 )));
104 }
105
106 if let Some(k) = n_neighbors {
107 if k == 0 {
108 return Err(SpatialError::ValueError(
109 "Number of _neighbors must be positive".to_string(),
110 ));
111 }
112 if k > n_points {
113 return Err(SpatialError::ValueError(format!(
114 "Number of _neighbors ({k}) cannot exceed number of points ({n_points})"
115 )));
116 }
117 }
118
119 let kdtree = KDTree::new(&points.to_owned())?;
121
122 Ok(Self {
123 points: points.to_owned(),
124 values: values.to_owned(),
125 dim,
126 n_points,
127 power,
128 n_neighbors,
129 kdtree,
130 })
131 }
132
133 pub fn interpolate(&self, point: &ArrayView1<f64>) -> SpatialResult<f64> {
147 if point.len() != self.dim {
149 return Err(SpatialError::DimensionError(format!(
150 "Query point has dimension {}, expected {}",
151 point.len(),
152 self.dim
153 )));
154 }
155
156 for i in 0..self.n_points {
158 let data_point = self.points.row(i);
159 if Self::is_same_point(&data_point, point) {
160 return Ok(self.values[i]);
161 }
162 }
163
164 let (indices, distances) = match self.n_neighbors {
166 Some(k) => {
167 self.kdtree
169 .query(point.as_slice().expect("Operation failed"), k)?
170 }
171 None => {
172 let mut indices = Vec::with_capacity(self.n_points);
174 let mut distances = Vec::with_capacity(self.n_points);
175
176 for i in 0..self.n_points {
177 let data_point = self.points.row(i);
178 let dist_sq = Self::squared_distance(&data_point, point);
179 indices.push(i);
180 distances.push(dist_sq);
181 }
182
183 (indices, distances)
184 }
185 };
186
187 let mut weighted_sum = 0.0;
189 let mut weight_sum = 0.0;
190
191 for i in 0..indices.len() {
192 let dist_sq = distances[i];
193
194 if dist_sq < 1e-10 {
196 return Ok(self.values[indices[i]]);
197 }
198
199 let weight = 1.0 / dist_sq.powf(self.power / 2.0);
201
202 weighted_sum += weight * self.values[indices[i]];
203 weight_sum += weight;
204 }
205
206 if weight_sum > 0.0 {
207 Ok(weighted_sum / weight_sum)
208 } else {
209 Err(SpatialError::ComputationError(
211 "Zero weight sum in IDW interpolation".to_string(),
212 ))
213 }
214 }
215
216 pub fn interpolate_many(&self, points: &ArrayView2<'_, f64>) -> SpatialResult<Array1<f64>> {
230 if points.ncols() != self.dim {
232 return Err(SpatialError::DimensionError(format!(
233 "Query _points have dimension {}, expected {}",
234 points.ncols(),
235 self.dim
236 )));
237 }
238
239 let n_queries = points.nrows();
240 let mut results = Array1::zeros(n_queries);
241
242 for i in 0..n_queries {
244 let point = points.row(i);
245 results[i] = self.interpolate(&point)?;
246 }
247
248 Ok(results)
249 }
250
251 pub fn power(&self) -> f64 {
253 self.power
254 }
255
256 pub fn n_neighbors(&self) -> Option<usize> {
258 self.n_neighbors
259 }
260
261 pub fn set_power(&mut self, power: f64) -> SpatialResult<()> {
271 if power < 0.0 {
272 return Err(SpatialError::ValueError(format!(
273 "Power parameter must be non-negative, got {power}"
274 )));
275 }
276
277 self.power = power;
278 Ok(())
279 }
280
281 pub fn set_n_neighbors(&mut self, _nneighbors: Option<usize>) -> SpatialResult<()> {
291 if let Some(k) = _nneighbors {
292 if k == 0 {
293 return Err(SpatialError::ValueError(
294 "Number of _neighbors must be positive".to_string(),
295 ));
296 }
297 if k > self.n_points {
298 return Err(SpatialError::ValueError(format!(
299 "Number of _neighbors ({}) cannot exceed number of points ({})",
300 k, self.n_points
301 )));
302 }
303 }
304
305 self.n_neighbors = _nneighbors;
306 Ok(())
307 }
308
309 fn is_same_point(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> bool {
320 Self::squared_distance(p1, p2) < 1e-10
321 }
322
323 fn squared_distance(p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
334 let mut sum_sq = 0.0;
335 for i in 0..p1.len().min(p2.len()) {
336 let diff = p1[i] - p2[i];
337 sum_sq += diff * diff;
338 }
339 sum_sq
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Corners of the unit square, sampled from f(x, y) = x + y.
    fn unit_square() -> (Array2<f64>, Array1<f64>) {
        (
            array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
            array![0.0, 1.0, 1.0, 2.0],
        )
    }

    /// Builds an interpolator, failing the test on any construction error.
    fn build(
        points: &Array2<f64>,
        values: &Array1<f64>,
        power: f64,
        n_neighbors: Option<usize>,
    ) -> IDWInterpolator {
        IDWInterpolator::new(&points.view(), &values.view(), power, n_neighbors)
            .expect("Operation failed")
    }

    /// Evaluates a 2-D interpolator at `(x, y)`.
    fn eval_at(interp: &IDWInterpolator, x: f64, y: f64) -> f64 {
        interp
            .interpolate(&array![x, y].view())
            .expect("Operation failed")
    }

    #[test]
    fn test_idw_interpolation_basic() {
        let (points, values) = unit_square();

        for &power in &[1.0, 2.0, 3.0] {
            let interp = build(&points, &values, power, None);

            // Every data site must be reproduced exactly.
            for (site, &expected) in points.rows().into_iter().zip(values.iter()) {
                assert_relative_eq!(eval_at(&interp, site[0], site[1]), expected, epsilon = 1e-10);
            }

            // The centre is equidistant from all corners, so it sees their mean.
            assert_relative_eq!(eval_at(&interp, 0.5, 0.5), 1.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_idw_with_neighbors() {
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.2, 0.8],
            [0.8, 0.2],
            [0.3, 0.3],
            [0.7, 0.7],
        ];
        // Samples of the linear field f(x, y) = x + y.
        let values: Array1<f64> = points.rows().into_iter().map(|p| p[0] + p[1]).collect();
        let query = array![0.6, 0.4];

        for &limit in &[None, Some(3)] {
            let estimate = build(&points, &values, 2.0, limit)
                .interpolate(&query.view())
                .expect("Operation failed");
            assert_relative_eq!(estimate, 1.0, epsilon = 0.1);
        }
    }

    #[test]
    fn test_interpolate_many() {
        let (points, values) = unit_square();
        let interp = build(&points, &values, 2.0, None);

        let queries = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]];
        let estimates = interp
            .interpolate_many(&queries.view())
            .expect("Operation failed");

        assert_eq!(estimates.len(), 5);
        // The first four queries are the data sites themselves.
        for (&got, &want) in estimates.iter().take(4).zip(values.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
        assert_relative_eq!(estimates[4], 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_setter_methods() {
        let (points, values) = unit_square();
        let mut interp = build(&points, &values, 2.0, None);

        assert_eq!(interp.power(), 2.0);
        assert_eq!(interp.n_neighbors(), None);

        interp.set_power(3.0).expect("Operation failed");
        assert_eq!(interp.power(), 3.0);

        interp.set_n_neighbors(Some(2)).expect("Operation failed");
        assert_eq!(interp.n_neighbors(), Some(2));

        // Out-of-range settings must be rejected.
        assert!(interp.set_power(-1.0).is_err());
        assert!(interp.set_n_neighbors(Some(0)).is_err());
        assert!(interp.set_n_neighbors(Some(10)).is_err());
    }

    #[test]
    fn test_error_handling() {
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let values = array![0.0, 1.0, 1.0];

        let interp = build(&points, &values, 2.0, None);
        // A 1-D query against 2-D data is a dimension mismatch.
        assert!(interp.interpolate(&array![0.0].view()).is_err());

        let try_new = |power: f64, n_neighbors: Option<usize>| {
            IDWInterpolator::new(&points.view(), &values.view(), power, n_neighbors)
        };
        assert!(try_new(-1.0, None).is_err());
        assert!(try_new(2.0, Some(0)).is_err());
        assert!(try_new(2.0, Some(10)).is_err());
    }
}