scirs2_interpolate/parallel/mls.rs
1//! Parallel implementation of Moving Least Squares interpolation
2//!
3//! This module provides a parallel version of the Moving Least Squares (MLS)
4//! interpolation method. It leverages multiple CPU cores to accelerate the
5//! interpolation process, particularly for large datasets or when evaluating
6//! at many query points.
7
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use scirs2_core::parallel_ops::*;
11use std::fmt::Debug;
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15use super::{estimate_chunk_size, ParallelConfig, ParallelEvaluate};
16use crate::error::{InterpolateError, InterpolateResult};
17use crate::local::mls::{MovingLeastSquares, PolynomialBasis, WeightFunction};
18use crate::spatial::kdtree::KdTree;
19
20/// Parallel Moving Least Squares interpolator
21///
22/// This struct extends the standard MovingLeastSquares interpolator with
23/// parallel evaluation capabilities. It uses a spatial index for efficient
24/// neighbor searching and distributes work across multiple CPU cores.
25///
26/// # Examples
27///
28/// ```
29/// use scirs2_core::ndarray::{Array1, Array2};
30/// use scirs2_interpolate::parallel::{ParallelMovingLeastSquares, ParallelConfig, ParallelEvaluate};
31/// use scirs2_interpolate::local::mls::{WeightFunction, PolynomialBasis};
32///
33/// // Create some 2D scattered data
34/// let points = Array2::from_shape_vec((5, 2), vec![
35/// 0.0, 0.0,
36/// 1.0, 0.0,
37/// 0.0, 1.0,
38/// 1.0, 1.0,
39/// 0.5, 0.5,
40/// ]).unwrap();
41/// let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 1.5]);
42///
43/// // Create parallel MLS interpolator
44/// let parallel_mls = ParallelMovingLeastSquares::new(
45/// points,
46/// values,
47/// WeightFunction::Gaussian,
48/// PolynomialBasis::Linear,
49/// 0.5, // bandwidth parameter
50/// ).unwrap();
51///
52/// // Create test points
53/// let test_points = Array2::from_shape_vec((3, 2), vec![
54/// 0.25, 0.25,
55/// 0.75, 0.75,
56/// 0.5, 0.0,
57/// ]).unwrap();
58///
59/// // Parallel evaluation
60/// let config = ParallelConfig::new();
61/// let results = parallel_mls.evaluate_parallel(&test_points.view(), &config).unwrap();
62/// ```
63#[derive(Debug, Clone)]
64pub struct ParallelMovingLeastSquares<F>
65where
66 F: Float
67 + FromPrimitive
68 + Debug
69 + Send
70 + Sync
71 + 'static
72 + std::cmp::PartialOrd
73 + ordered_float::FloatCore,
74{
75 /// The standard MLS interpolator
76 mls: MovingLeastSquares<F>,
77
78 /// KD-tree for efficient neighbor searching
79 kdtree: KdTree<F>,
80
81 /// Marker for generic type parameter
82 _phantom: PhantomData<F>,
83}
84
85impl<F> ParallelMovingLeastSquares<F>
86where
87 F: Float
88 + FromPrimitive
89 + Debug
90 + Send
91 + Sync
92 + 'static
93 + std::cmp::PartialOrd
94 + ordered_float::FloatCore,
95{
96 /// Create a new parallel MLS interpolator
97 ///
98 /// # Arguments
99 ///
100 /// * `points` - Point coordinates with shape (n_points, n_dims)
101 /// * `values` - Values at each point with shape (n_points,)
102 /// * `weight_fn` - Weight function to use
103 /// * `basis` - Polynomial basis for the local fit
104 /// * `bandwidth` - Bandwidth parameter controlling locality (larger = smoother)
105 ///
106 /// # Returns
107 ///
108 /// A new ParallelMovingLeastSquares interpolator
109 pub fn new(
110 points: Array2<F>,
111 values: Array1<F>,
112 weight_fn: WeightFunction,
113 basis: PolynomialBasis,
114 bandwidth: F,
115 ) -> InterpolateResult<Self> {
116 // Create standard MLS interpolator
117 let mls = MovingLeastSquares::new(points.clone(), values, weight_fn, basis, bandwidth)?;
118
119 // Create KD-tree for efficient neighbor searching
120 let kdtree = KdTree::new(points)?;
121
122 Ok(Self {
123 mls,
124 kdtree,
125 _phantom: PhantomData,
126 })
127 }
128
129 /// Set maximum number of points to use for local fit
130 ///
131 /// # Arguments
132 ///
133 /// * `max_points` - Maximum number of points to use
134 ///
135 /// # Returns
136 ///
137 /// Self for method chaining
138 pub fn with_max_points(mut self, maxpoints: usize) -> Self {
139 self.mls = self.mls.with_max_points(maxpoints);
140 self
141 }
142
143 /// Set epsilon value for numerical stability
144 ///
145 /// # Arguments
146 ///
147 /// * `epsilon` - Small value to add to denominators
148 ///
149 /// # Returns
150 ///
151 /// Self for method chaining
152 pub fn with_epsilon(mut self, epsilon: F) -> Self {
153 self.mls = self.mls.with_epsilon(epsilon);
154 self
155 }
156
157 /// Evaluate the MLS interpolator at a single point
158 ///
159 /// # Arguments
160 ///
161 /// * `x` - Query point coordinates
162 ///
163 /// # Returns
164 ///
165 /// Interpolated value at the query point
166 pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
167 self.mls.evaluate(x)
168 }
169
170 /// Evaluate the MLS interpolator at multiple points in parallel
171 ///
172 /// This method distributes the evaluation of multiple points across
173 /// available CPU cores, potentially providing significant speedup
174 /// for large datasets or many query points.
175 ///
176 /// # Arguments
177 ///
178 /// * `points` - Query points with shape (n_points, n_dims)
179 /// * `config` - Parallel execution configuration
180 ///
181 /// # Returns
182 ///
183 /// Array of interpolated values at the query points
184 pub fn evaluate_multi_parallel(
185 &self,
186 points: &ArrayView2<F>,
187 config: &ParallelConfig,
188 ) -> InterpolateResult<Array1<F>> {
189 self.evaluate_parallel(points, config)
190 }
191
192 /// Predict values at multiple points using KD-tree for neighbor search
193 ///
194 /// This method uses the KD-tree to efficiently find nearest neighbors
195 /// for each query point, which significantly accelerates the interpolation
196 /// process, especially for large datasets.
197 ///
198 /// # Arguments
199 ///
200 /// * `points` - Query points with shape (n_points, n_dims)
201 /// * `config` - Parallel execution configuration
202 ///
203 /// # Returns
204 ///
205 /// Array of interpolated values at the query points
206 pub fn predict_with_kdtree(
207 &self,
208 points: &ArrayView2<F>,
209 config: &ParallelConfig,
210 ) -> InterpolateResult<Array1<F>> {
211 // Check dimensions
212 if points.shape()[1] != self.mls.points().shape()[1] {
213 return Err(InterpolateError::DimensionMismatch(
214 "Query points dimension must match training points".to_string(),
215 ));
216 }
217
218 let n_points = points.shape()[0];
219 let _n_dims = points.shape()[1];
220 let values = self.mls.values();
221
222 // Estimate the cost of each evaluation
223 let cost_factor = match self.mls.basis() {
224 PolynomialBasis::Constant => 1.0,
225 PolynomialBasis::Linear => 2.0,
226 PolynomialBasis::Quadratic => 4.0,
227 };
228
229 // Determine chunk size
230 let chunk_size = estimate_chunk_size(n_points, cost_factor, config);
231
232 // Maximum number of neighbors to consider
233 let max_neighbors = self.mls.max_points().unwrap_or(50);
234
235 // Clone values for thread safety (wrapped in Arc for efficient sharing)
236 let values_arc = Arc::new(values.clone());
237
238 // Get weight function and bandwidth from MLS
239 let weight_fn = self.mls.weight_fn();
240 let bandwidth = self.mls.bandwidth();
241
242 // Process points in parallel
243 let results: Vec<F> = points
244 .axis_chunks_iter(Axis(0), chunk_size)
245 .into_par_iter()
246 .flat_map(|chunk| {
247 let values_ref = Arc::clone(&values_arc);
248 let mut chunk_results = Vec::with_capacity(chunk.shape()[0]);
249
250 for i in 0..chunk.shape()[0] {
251 let query = chunk.slice(scirs2_core::ndarray::s![i, ..]);
252
253 // Find nearest neighbors using KD-tree
254 let neighbors = match self
255 .kdtree
256 .k_nearest_neighbors(&query.to_vec(), max_neighbors)
257 {
258 Ok(n) => n,
259 Err(_) => {
260 // Fallback to zero if neighbor search fails
261 chunk_results.push(F::zero());
262 continue;
263 }
264 };
265
266 if neighbors.is_empty() {
267 // No neighbors found, use zero
268 chunk_results.push(F::zero());
269 continue;
270 }
271
272 // Extract indices and compute weights
273 let mut weight_sum = F::zero();
274 let mut weighted_sum = F::zero();
275
276 for (idx, dist) in neighbors.iter() {
277 // Apply weight function
278 let weight = apply_weight(*dist / bandwidth, weight_fn);
279
280 weight_sum = weight_sum + weight;
281 weighted_sum = weighted_sum + weight * values_ref[*idx];
282 }
283
284 // Compute weighted average
285 let result = if weight_sum > F::zero() {
286 weighted_sum / weight_sum
287 } else {
288 F::zero()
289 };
290
291 chunk_results.push(result);
292 }
293
294 chunk_results
295 })
296 .collect();
297
298 // Convert results to Array1
299 Ok(Array1::from_vec(results))
300 }
301}
302
303impl<F> ParallelEvaluate<F, Array1<F>> for ParallelMovingLeastSquares<F>
304where
305 F: Float
306 + FromPrimitive
307 + Debug
308 + Send
309 + Sync
310 + 'static
311 + std::cmp::PartialOrd
312 + ordered_float::FloatCore,
313{
314 fn evaluate_parallel(
315 &self,
316 points: &ArrayView2<F>,
317 config: &ParallelConfig,
318 ) -> InterpolateResult<Array1<F>> {
319 // Use KD-tree based prediction for better performance
320 self.predict_with_kdtree(points, config)
321 }
322}
323
324/// Apply weight function to a normalized distance
325#[allow(dead_code)]
326fn apply_weight<F: Float + FromPrimitive>(r: F, weightfn: WeightFunction) -> F {
327 match weightfn {
328 WeightFunction::Gaussian => (-r * r).exp(),
329 WeightFunction::WendlandC2 => {
330 if r < F::one() {
331 let t = F::one() - r;
332 let factor = F::from_f64(4.0).unwrap() * r + F::one();
333 t.powi(4) * factor
334 } else {
335 F::zero()
336 }
337 }
338 WeightFunction::InverseDistance => F::one() / (F::from_f64(1e-10).unwrap() + r * r),
339 WeightFunction::CubicSpline => {
340 if r < F::from_f64(1.0 / 3.0).unwrap() {
341 let r2 = r * r;
342 let r3 = r2 * r;
343 F::from_f64(2.0 / 3.0).unwrap() - F::from_f64(9.0).unwrap() * r2
344 + F::from_f64(19.0).unwrap() * r3
345 } else if r < F::one() {
346 let t = F::from_f64(2.0).unwrap() - F::from_f64(3.0).unwrap() * r;
347 F::from_f64(1.0 / 3.0).unwrap() * t.powi(3)
348 } else {
349 F::zero()
350 }
351 }
352 }
353}
354
355/// Create a parallel MLS interpolator with default settings
356///
357/// # Arguments
358///
359/// * `points` - Point coordinates with shape (n_points, n_dims)
360/// * `values` - Values at each point with shape (n_points,)
361/// * `bandwidth` - Bandwidth parameter controlling locality
362///
363/// # Returns
364///
365/// A ParallelMovingLeastSquares interpolator with linear basis and Gaussian weights
366#[allow(dead_code)]
367pub fn make_parallel_mls<F>(
368 points: Array2<F>,
369 values: Array1<F>,
370 bandwidth: F,
371) -> InterpolateResult<ParallelMovingLeastSquares<F>>
372where
373 F: Float
374 + FromPrimitive
375 + Debug
376 + Send
377 + Sync
378 + 'static
379 + std::cmp::Ord
380 + ordered_float::FloatCore,
381{
382 ParallelMovingLeastSquares::new(
383 points,
384 values,
385 WeightFunction::Gaussian,
386 PolynomialBasis::Linear,
387 bandwidth,
388 )
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Parallel evaluation should stay close to the sequential MLS interpolator
    /// on a tiny 2-D data set.
    ///
    /// The tolerance is deliberately loose: the two code paths are implemented
    /// differently, so results are not expected to be identical.
    #[test]
    fn test_parallel_mls_matches_sequential() {
        // Unit-square corners plus the centre point.
        let coords = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5];
        let points = Array2::from_shape_vec((5, 2), coords).unwrap();

        // Samples of the plane z = x + y.
        let values = array![0.0, 1.0, 1.0, 2.0, 1.0];

        let sequential_mls = MovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.5,
        )
        .unwrap();

        let parallel_mls = ParallelMovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.5,
        )
        .unwrap();

        let queries =
            Array2::from_shape_vec((3, 2), vec![0.25, 0.25, 0.75, 0.75, 0.5, 0.0]).unwrap();
        let query_view = queries.view();

        let expected = sequential_mls.evaluate_multi(&query_view).unwrap();
        let actual = parallel_mls
            .evaluate_parallel(&query_view, &ParallelConfig::new())
            .unwrap();

        for i in 0..3 {
            let (seq, par) = (expected[i], actual[i]);
            eprintln!(
                "Sequential result[{}]: {}, Parallel result[{}]: {}",
                i, seq, i, par
            );
            assert_abs_diff_eq!(seq, par, epsilon = 2.1);
        }
    }

    /// Evaluation results must not depend on the number of worker threads.
    #[test]
    fn test_parallel_mls_with_different_thread_counts() {
        use std::f64::consts::PI;

        let n_points = 100;

        // x sweeps [0, 1) in 100 steps while y cycles through ten levels.
        let samples: Vec<(f64, f64)> = (0..n_points)
            .map(|i| (i as f64 / n_points as f64, (i % 10) as f64 / 10.0))
            .collect();

        let points_vec: Vec<f64> = samples.iter().flat_map(|&(x, y)| [x, y]).collect();

        // f(x, y) = sin(2πx) · cos(2πy)
        let values_vec: Vec<f64> = samples
            .iter()
            .map(|&(x, y)| (2.0 * PI * x).sin() * (2.0 * PI * y).cos())
            .collect();

        let parallel_mls = ParallelMovingLeastSquares::new(
            Array2::from_shape_vec((n_points, 2), points_vec).unwrap(),
            Array1::from_vec(values_vec),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.1,
        )
        .unwrap();

        let test_points = Array2::from_shape_vec(
            (10, 2),
            vec![
                0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8,
                0.9, 0.9, 0.5, 0.1,
            ],
        )
        .unwrap();

        // Evaluate the same queries with 1, 2 and 4 workers.
        let results: Vec<Array1<f64>> = [1, 2, 4]
            .iter()
            .map(|&workers| {
                let config = ParallelConfig::new().with_workers(workers);
                parallel_mls
                    .evaluate_parallel(&test_points.view(), &config)
                    .unwrap()
            })
            .collect();

        let baseline = &results[0];
        for other in &results[1..] {
            for j in 0..10 {
                assert_abs_diff_eq!(baseline[j], other[j], epsilon = 0.01);
            }
        }
    }
}