scirs2_interpolate/parallel/
mls.rs

1//! Parallel implementation of Moving Least Squares interpolation
2//!
3//! This module provides a parallel version of the Moving Least Squares (MLS)
4//! interpolation method. It leverages multiple CPU cores to accelerate the
5//! interpolation process, particularly for large datasets or when evaluating
6//! at many query points.
7
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use scirs2_core::parallel_ops::*;
11use std::fmt::Debug;
12use std::marker::PhantomData;
13use std::sync::Arc;
14
15use super::{estimate_chunk_size, ParallelConfig, ParallelEvaluate};
16use crate::error::{InterpolateError, InterpolateResult};
17use crate::local::mls::{MovingLeastSquares, PolynomialBasis, WeightFunction};
18use crate::spatial::kdtree::KdTree;
19
20/// Parallel Moving Least Squares interpolator
21///
22/// This struct extends the standard MovingLeastSquares interpolator with
23/// parallel evaluation capabilities. It uses a spatial index for efficient
24/// neighbor searching and distributes work across multiple CPU cores.
25///
26/// # Examples
27///
28/// ```
29/// use scirs2_core::ndarray::{Array1, Array2};
30/// use scirs2_interpolate::parallel::{ParallelMovingLeastSquares, ParallelConfig, ParallelEvaluate};
31/// use scirs2_interpolate::local::mls::{WeightFunction, PolynomialBasis};
32///
33/// // Create some 2D scattered data
34/// let points = Array2::from_shape_vec((5, 2), vec![
35///     0.0, 0.0,
36///     1.0, 0.0,
37///     0.0, 1.0,
38///     1.0, 1.0,
39///     0.5, 0.5,
40/// ]).unwrap();
41/// let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 1.5]);
42///
43/// // Create parallel MLS interpolator
44/// let parallel_mls = ParallelMovingLeastSquares::new(
45///     points,
46///     values,
47///     WeightFunction::Gaussian,
48///     PolynomialBasis::Linear,
49///     0.5, // bandwidth parameter
50/// ).unwrap();
51///
52/// // Create test points
53/// let test_points = Array2::from_shape_vec((3, 2), vec![
54///     0.25, 0.25,
55///     0.75, 0.75,
56///     0.5, 0.0,
57/// ]).unwrap();
58///
59/// // Parallel evaluation
60/// let config = ParallelConfig::new();
61/// let results = parallel_mls.evaluate_parallel(&test_points.view(), &config).unwrap();
62/// ```
63#[derive(Debug, Clone)]
64pub struct ParallelMovingLeastSquares<F>
65where
66    F: Float
67        + FromPrimitive
68        + Debug
69        + Send
70        + Sync
71        + 'static
72        + std::cmp::PartialOrd
73        + ordered_float::FloatCore,
74{
75    /// The standard MLS interpolator
76    mls: MovingLeastSquares<F>,
77
78    /// KD-tree for efficient neighbor searching
79    kdtree: KdTree<F>,
80
81    /// Marker for generic type parameter
82    _phantom: PhantomData<F>,
83}
84
85impl<F> ParallelMovingLeastSquares<F>
86where
87    F: Float
88        + FromPrimitive
89        + Debug
90        + Send
91        + Sync
92        + 'static
93        + std::cmp::PartialOrd
94        + ordered_float::FloatCore,
95{
96    /// Create a new parallel MLS interpolator
97    ///
98    /// # Arguments
99    ///
100    /// * `points` - Point coordinates with shape (n_points, n_dims)
101    /// * `values` - Values at each point with shape (n_points,)
102    /// * `weight_fn` - Weight function to use
103    /// * `basis` - Polynomial basis for the local fit
104    /// * `bandwidth` - Bandwidth parameter controlling locality (larger = smoother)
105    ///
106    /// # Returns
107    ///
108    /// A new ParallelMovingLeastSquares interpolator
109    pub fn new(
110        points: Array2<F>,
111        values: Array1<F>,
112        weight_fn: WeightFunction,
113        basis: PolynomialBasis,
114        bandwidth: F,
115    ) -> InterpolateResult<Self> {
116        // Create standard MLS interpolator
117        let mls = MovingLeastSquares::new(points.clone(), values, weight_fn, basis, bandwidth)?;
118
119        // Create KD-tree for efficient neighbor searching
120        let kdtree = KdTree::new(points)?;
121
122        Ok(Self {
123            mls,
124            kdtree,
125            _phantom: PhantomData,
126        })
127    }
128
129    /// Set maximum number of points to use for local fit
130    ///
131    /// # Arguments
132    ///
133    /// * `max_points` - Maximum number of points to use
134    ///
135    /// # Returns
136    ///
137    /// Self for method chaining
138    pub fn with_max_points(mut self, maxpoints: usize) -> Self {
139        self.mls = self.mls.with_max_points(maxpoints);
140        self
141    }
142
143    /// Set epsilon value for numerical stability
144    ///
145    /// # Arguments
146    ///
147    /// * `epsilon` - Small value to add to denominators
148    ///
149    /// # Returns
150    ///
151    /// Self for method chaining
152    pub fn with_epsilon(mut self, epsilon: F) -> Self {
153        self.mls = self.mls.with_epsilon(epsilon);
154        self
155    }
156
157    /// Evaluate the MLS interpolator at a single point
158    ///
159    /// # Arguments
160    ///
161    /// * `x` - Query point coordinates
162    ///
163    /// # Returns
164    ///
165    /// Interpolated value at the query point
166    pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
167        self.mls.evaluate(x)
168    }
169
170    /// Evaluate the MLS interpolator at multiple points in parallel
171    ///
172    /// This method distributes the evaluation of multiple points across
173    /// available CPU cores, potentially providing significant speedup
174    /// for large datasets or many query points.
175    ///
176    /// # Arguments
177    ///
178    /// * `points` - Query points with shape (n_points, n_dims)
179    /// * `config` - Parallel execution configuration
180    ///
181    /// # Returns
182    ///
183    /// Array of interpolated values at the query points
184    pub fn evaluate_multi_parallel(
185        &self,
186        points: &ArrayView2<F>,
187        config: &ParallelConfig,
188    ) -> InterpolateResult<Array1<F>> {
189        self.evaluate_parallel(points, config)
190    }
191
192    /// Predict values at multiple points using KD-tree for neighbor search
193    ///
194    /// This method uses the KD-tree to efficiently find nearest neighbors
195    /// for each query point, which significantly accelerates the interpolation
196    /// process, especially for large datasets.
197    ///
198    /// # Arguments
199    ///
200    /// * `points` - Query points with shape (n_points, n_dims)
201    /// * `config` - Parallel execution configuration
202    ///
203    /// # Returns
204    ///
205    /// Array of interpolated values at the query points
206    pub fn predict_with_kdtree(
207        &self,
208        points: &ArrayView2<F>,
209        config: &ParallelConfig,
210    ) -> InterpolateResult<Array1<F>> {
211        // Check dimensions
212        if points.shape()[1] != self.mls.points().shape()[1] {
213            return Err(InterpolateError::DimensionMismatch(
214                "Query points dimension must match training points".to_string(),
215            ));
216        }
217
218        let n_points = points.shape()[0];
219        let _n_dims = points.shape()[1];
220        let values = self.mls.values();
221
222        // Estimate the cost of each evaluation
223        let cost_factor = match self.mls.basis() {
224            PolynomialBasis::Constant => 1.0,
225            PolynomialBasis::Linear => 2.0,
226            PolynomialBasis::Quadratic => 4.0,
227        };
228
229        // Determine chunk size
230        let chunk_size = estimate_chunk_size(n_points, cost_factor, config);
231
232        // Maximum number of neighbors to consider
233        let max_neighbors = self.mls.max_points().unwrap_or(50);
234
235        // Clone values for thread safety (wrapped in Arc for efficient sharing)
236        let values_arc = Arc::new(values.clone());
237
238        // Get weight function and bandwidth from MLS
239        let weight_fn = self.mls.weight_fn();
240        let bandwidth = self.mls.bandwidth();
241
242        // Process points in parallel
243        let results: Vec<F> = points
244            .axis_chunks_iter(Axis(0), chunk_size)
245            .into_par_iter()
246            .flat_map(|chunk| {
247                let values_ref = Arc::clone(&values_arc);
248                let mut chunk_results = Vec::with_capacity(chunk.shape()[0]);
249
250                for i in 0..chunk.shape()[0] {
251                    let query = chunk.slice(scirs2_core::ndarray::s![i, ..]);
252
253                    // Find nearest neighbors using KD-tree
254                    let neighbors = match self
255                        .kdtree
256                        .k_nearest_neighbors(&query.to_vec(), max_neighbors)
257                    {
258                        Ok(n) => n,
259                        Err(_) => {
260                            // Fallback to zero if neighbor search fails
261                            chunk_results.push(F::zero());
262                            continue;
263                        }
264                    };
265
266                    if neighbors.is_empty() {
267                        // No neighbors found, use zero
268                        chunk_results.push(F::zero());
269                        continue;
270                    }
271
272                    // Extract indices and compute weights
273                    let mut weight_sum = F::zero();
274                    let mut weighted_sum = F::zero();
275
276                    for (idx, dist) in neighbors.iter() {
277                        // Apply weight function
278                        let weight = apply_weight(*dist / bandwidth, weight_fn);
279
280                        weight_sum = weight_sum + weight;
281                        weighted_sum = weighted_sum + weight * values_ref[*idx];
282                    }
283
284                    // Compute weighted average
285                    let result = if weight_sum > F::zero() {
286                        weighted_sum / weight_sum
287                    } else {
288                        F::zero()
289                    };
290
291                    chunk_results.push(result);
292                }
293
294                chunk_results
295            })
296            .collect();
297
298        // Convert results to Array1
299        Ok(Array1::from_vec(results))
300    }
301}
302
303impl<F> ParallelEvaluate<F, Array1<F>> for ParallelMovingLeastSquares<F>
304where
305    F: Float
306        + FromPrimitive
307        + Debug
308        + Send
309        + Sync
310        + 'static
311        + std::cmp::PartialOrd
312        + ordered_float::FloatCore,
313{
314    fn evaluate_parallel(
315        &self,
316        points: &ArrayView2<F>,
317        config: &ParallelConfig,
318    ) -> InterpolateResult<Array1<F>> {
319        // Use KD-tree based prediction for better performance
320        self.predict_with_kdtree(points, config)
321    }
322}
323
324/// Apply weight function to a normalized distance
325#[allow(dead_code)]
326fn apply_weight<F: Float + FromPrimitive>(r: F, weightfn: WeightFunction) -> F {
327    match weightfn {
328        WeightFunction::Gaussian => (-r * r).exp(),
329        WeightFunction::WendlandC2 => {
330            if r < F::one() {
331                let t = F::one() - r;
332                let factor = F::from_f64(4.0).unwrap() * r + F::one();
333                t.powi(4) * factor
334            } else {
335                F::zero()
336            }
337        }
338        WeightFunction::InverseDistance => F::one() / (F::from_f64(1e-10).unwrap() + r * r),
339        WeightFunction::CubicSpline => {
340            if r < F::from_f64(1.0 / 3.0).unwrap() {
341                let r2 = r * r;
342                let r3 = r2 * r;
343                F::from_f64(2.0 / 3.0).unwrap() - F::from_f64(9.0).unwrap() * r2
344                    + F::from_f64(19.0).unwrap() * r3
345            } else if r < F::one() {
346                let t = F::from_f64(2.0).unwrap() - F::from_f64(3.0).unwrap() * r;
347                F::from_f64(1.0 / 3.0).unwrap() * t.powi(3)
348            } else {
349                F::zero()
350            }
351        }
352    }
353}
354
355/// Create a parallel MLS interpolator with default settings
356///
357/// # Arguments
358///
359/// * `points` - Point coordinates with shape (n_points, n_dims)
360/// * `values` - Values at each point with shape (n_points,)
361/// * `bandwidth` - Bandwidth parameter controlling locality
362///
363/// # Returns
364///
365/// A ParallelMovingLeastSquares interpolator with linear basis and Gaussian weights
366#[allow(dead_code)]
367pub fn make_parallel_mls<F>(
368    points: Array2<F>,
369    values: Array1<F>,
370    bandwidth: F,
371) -> InterpolateResult<ParallelMovingLeastSquares<F>>
372where
373    F: Float
374        + FromPrimitive
375        + Debug
376        + Send
377        + Sync
378        + 'static
379        + std::cmp::Ord
380        + ordered_float::FloatCore,
381{
382    ParallelMovingLeastSquares::new(
383        points,
384        values,
385        WeightFunction::Gaussian,
386        PolynomialBasis::Linear,
387        bandwidth,
388    )
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Parallel MLS over `points`/`values` with a Gaussian kernel and linear basis.
    fn gaussian_linear(
        points: &Array2<f64>,
        values: &Array1<f64>,
        bandwidth: f64,
    ) -> ParallelMovingLeastSquares<f64> {
        ParallelMovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            bandwidth,
        )
        .unwrap()
    }

    #[test]
    fn test_parallel_mls_matches_sequential() {
        // Corners of the unit square plus its centre, sampled from z = x + y.
        let points = Array2::from_shape_vec(
            (5, 2),
            vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5],
        )
        .unwrap();
        let values = array![0.0, 1.0, 1.0, 2.0, 1.0];

        let sequential_mls = MovingLeastSquares::new(
            points.clone(),
            values.clone(),
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            0.5,
        )
        .unwrap();
        let parallel_mls = gaussian_linear(&points, &values, 0.5);

        let queries =
            Array2::from_shape_vec((3, 2), vec![0.25, 0.25, 0.75, 0.75, 0.5, 0.0]).unwrap();
        let config = ParallelConfig::new();

        let sequential_results = sequential_mls.evaluate_multi(&queries.view()).unwrap();
        let parallel_results = parallel_mls
            .evaluate_parallel(&queries.view(), &config)
            .unwrap();

        // Loose tolerance: the parallel path uses a different estimator than the
        // sequential local fit, so only rough agreement is expected.
        (0..3).for_each(|i| {
            let (seq, par) = (sequential_results[i], parallel_results[i]);
            eprintln!(
                "Sequential result[{}]: {}, Parallel result[{}]: {}",
                i, seq, i, par
            );
            assert_abs_diff_eq!(seq, par, epsilon = 2.1);
        });
    }

    #[test]
    fn test_parallel_mls_with_different_thread_counts() {
        use std::f64::consts::PI;

        // 100 scattered samples of f(x, y) = sin(2*pi*x) * cos(2*pi*y).
        let n_points: usize = 100;
        let coords: Vec<(f64, f64)> = (0..n_points)
            .map(|i| (i as f64 / n_points as f64, (i % 10) as f64 / 10.0))
            .collect();
        let points_vec: Vec<f64> = coords.iter().flat_map(|&(x, y)| [x, y]).collect();
        let values_vec: Vec<f64> = coords
            .iter()
            .map(|&(x, y)| (2.0 * PI * x).sin() * (2.0 * PI * y).cos())
            .collect();

        let points = Array2::from_shape_vec((n_points, 2), points_vec).unwrap();
        let values = Array1::from_vec(values_vec);
        let parallel_mls = gaussian_linear(&points, &values, 0.1);

        let test_points = Array2::from_shape_vec(
            (10, 2),
            vec![
                0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4, 0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8,
                0.9, 0.9, 0.5, 0.1,
            ],
        )
        .unwrap();

        let configs = [
            ParallelConfig::new().with_workers(1),
            ParallelConfig::new().with_workers(2),
            ParallelConfig::new().with_workers(4),
        ];
        let results: Vec<_> = configs
            .iter()
            .map(|config| {
                parallel_mls
                    .evaluate_parallel(&test_points.view(), config)
                    .unwrap()
            })
            .collect();

        // The worker count must not change any output.
        let (baseline, others) = results.split_first().unwrap();
        for other in others {
            for j in 0..10 {
                assert_abs_diff_eq!(baseline[j], other[j], epsilon = 0.01);
            }
        }
    }
}