scirs2_interpolate/local/
mls.rs

1//! Moving Least Squares Interpolation
2//!
3//! This module provides an implementation of Moving Least Squares (MLS) interpolation,
4//! which is particularly useful for scattered data with potentially noisy values.
5//! MLS creates a smooth approximation function by fitting local polynomials at each
6//! evaluation point using weighted least squares where closer points have higher weights.
7//!
8//! The technique is popular in computer graphics, mesh processing, and scientific computing
9//! for its ability to handle irregularly spaced data and provide smooth results.
10
11use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
12use scirs2_core::numeric::{Float, FromPrimitive};
13use std::fmt::Debug;
14use std::marker::PhantomData;
15
16use crate::error::{InterpolateError, InterpolateResult};
17
/// Weight (kernel) functions available for MLS.
///
/// In the formulas below `r` is the Euclidean distance between the query point and a
/// data point and `h` is the interpolator's bandwidth. The interpolator rescales the
/// resulting weights so that they sum to one whenever their sum is positive.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WeightFunction {
    /// Gaussian weight: w(r) = exp(-r²/h²)
    ///
    /// Treated as having unbounded support: no data point is discarded by distance.
    Gaussian,

    /// Wendland C2 compactly supported function
    /// w(r) = (1-r/h)⁴(4r/h+1) for r < h, 0 otherwise
    WendlandC2,

    /// Inverse distance: w(r) = 1/(ε + (r/h)²)
    ///
    /// The distance is divided by the bandwidth before it is squared; ε is the
    /// interpolator's `epsilon` setting. Treated as having unbounded support.
    InverseDistance,

    /// Cubic spline kernel with compact support: w(r) = 0 for r ≥ h.
    ///
    /// The piecewise-cubic expression used inside the support is the one implemented in
    /// `MovingLeastSquares::compute_weights`.
    CubicSpline,
}
37
/// Polynomial basis types for the local approximation.
///
/// For `d`-dimensional data the basis sizes are 1 (constant), `d + 1` (linear) and
/// `(d + 1)(d + 2) / 2` (quadratic). This is also the minimum number of data points the
/// interpolator selects for a local fit.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PolynomialBasis {
    /// Constant basis: [1]
    Constant,

    /// Linear basis: [1, x, y, ...]
    Linear,

    /// Quadratic basis: [1, x, y, ..., x², xy, y², ...]
    ///
    /// The second-order terms are the products `x_j * x_k` for every pair `j <= k`.
    Quadratic,
}
50
/// Moving Least Squares interpolator for scattered data
///
/// This interpolator uses a weighted least squares fit at each evaluation point,
/// where the weights depend on the distance from the evaluation point to the data points.
/// The result is a smooth function that approximates the data points.
///
/// Instances are created with [`MovingLeastSquares::new`] and can be tuned afterwards
/// with the chaining setters `with_max_points` and `with_epsilon`.
///
/// # Examples
///
/// ```
/// # #[cfg(feature = "linalg")]
/// # {
/// use scirs2_core::ndarray::{Array1, Array2};
/// use scirs2_interpolate::local::mls::{MovingLeastSquares, WeightFunction, PolynomialBasis};
///
/// // Create some 2D scattered data
/// let points = Array2::from_shape_vec((5, 2), vec![
///     0.0, 0.0,
///     1.0, 0.0,
///     0.0, 1.0,
///     1.0, 1.0,
///     0.5, 0.5,
/// ]).unwrap();
/// let values = Array1::from_vec(vec![0.0, 1.0, 1.0, 2.0, 1.5]);
///
/// // Create MLS interpolator (simplified configuration for test)
/// let mls = MovingLeastSquares::<f64>::new(
///     points,
///     values,
///     WeightFunction::Gaussian,
///     PolynomialBasis::Constant, // Using constant basis to avoid linalg feature requirement
///     0.5, // bandwidth parameter
/// ).unwrap();
///
/// // Evaluate at a new point
/// let query = Array1::from_vec(vec![0.25, 0.25]);
/// let result = mls.evaluate(&query.view()).unwrap();
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct MovingLeastSquares<F>
where
    F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
{
    /// Points coordinates (input locations), one row per point: shape (n_points, n_dims)
    points: Array2<F>,

    /// Values at points; `values[i]` belongs to row `i` of `points`
    values: Array1<F>,

    /// Weight function to use
    weight_fn: WeightFunction,

    /// Polynomial basis to use
    basis: PolynomialBasis,

    /// Bandwidth parameter (h); distances are divided by it before weighting
    bandwidth: F,

    /// Small value added to the denominator of the inverse-distance weight to avoid
    /// division by zero (initialized to 1e-10 by `new`)
    epsilon: F,

    /// Maximum number of nearest points to consider for a local fit (for efficiency);
    /// `None` means all points are candidates
    max_points: Option<usize>,

    /// Marker for generic type parameter
    // NOTE(review): `F` already appears in the fields above, so this marker looks
    // redundant; it is kept because `new` initializes it.
    _phantom: PhantomData<F>,
}
118
119impl<F> MovingLeastSquares<F>
120where
121    F: Float + FromPrimitive + Debug + 'static + std::cmp::PartialOrd,
122{
123    /// Create a new MovingLeastSquares interpolator
124    ///
125    /// # Arguments
126    ///
127    /// * `points` - Point coordinates with shape (n_points, n_dims)
128    /// * `values` - Values at each point with shape (n_points,)
129    /// * `weight_fn` - Weight function to use
130    /// * `basis` - Polynomial basis for the local fit
131    /// * `bandwidth` - Bandwidth parameter controlling locality (larger = smoother)
132    ///
133    /// # Returns
134    ///
135    /// A new MovingLeastSquares interpolator
136    pub fn new(
137        points: Array2<F>,
138        values: Array1<F>,
139        weight_fn: WeightFunction,
140        basis: PolynomialBasis,
141        bandwidth: F,
142    ) -> InterpolateResult<Self> {
143        // Validate inputs
144        if points.shape()[0] != values.len() {
145            return Err(InterpolateError::DimensionMismatch(
146                "Number of points must match number of values".to_string(),
147            ));
148        }
149
150        if points.shape()[0] < 2 {
151            return Err(InterpolateError::InvalidValue(
152                "At least 2 points are required for MLS interpolation".to_string(),
153            ));
154        }
155
156        if bandwidth <= F::zero() {
157            return Err(InterpolateError::InvalidValue(
158                "Bandwidth parameter must be positive".to_string(),
159            ));
160        }
161
162        Ok(Self {
163            points,
164            values,
165            weight_fn,
166            basis,
167            bandwidth,
168            epsilon: F::from_f64(1e-10).unwrap(),
169            max_points: None,
170            _phantom: PhantomData,
171        })
172    }
173
174    /// Set maximum number of points to use for local fit
175    ///
176    /// This is useful for large datasets where using all points
177    /// would be computationally expensive.
178    ///
179    /// # Arguments
180    ///
181    /// * `max_points` - Maximum number of points to use
182    ///
183    /// # Returns
184    ///
185    /// Self for method chaining
186    pub fn with_max_points(mut self, maxpoints: usize) -> Self {
187        self.max_points = Some(maxpoints);
188        self
189    }
190
191    /// Set epsilon value for numerical stability
192    ///
193    /// # Arguments
194    ///
195    /// * `epsilon` - Small value to add to denominators
196    ///
197    /// # Returns
198    ///
199    /// Self for method chaining
200    pub fn with_epsilon(mut self, epsilon: F) -> Self {
201        self.epsilon = epsilon;
202        self
203    }
204
205    /// Evaluate the MLS approximation at a given point
206    ///
207    /// # Arguments
208    ///
209    /// * `x` - Query point with shape (n_dims,)
210    ///
211    /// # Returns
212    ///
213    /// Interpolated value at the query point
214    pub fn evaluate(&self, x: &ArrayView1<F>) -> InterpolateResult<F> {
215        // Check dimensions
216        if x.len() != self.points.shape()[1] {
217            return Err(InterpolateError::DimensionMismatch(
218                "Query point dimension must match training points".to_string(),
219            ));
220        }
221
222        // Get points to use for local fit
223        let (indices, distances) = self.find_relevant_points(x)?;
224
225        if indices.is_empty() {
226            return Err(InterpolateError::invalid_input(
227                "No points found within effective range".to_string(),
228            ));
229        }
230
231        // Compute weights
232        let weights = self.compute_weights(&distances)?;
233
234        // Create basis functions for these points
235        let basis_functions = self.create_basis_functions(&indices, x)?;
236
237        // Weighted least squares solution
238        let result = self.solve_weighted_least_squares(&indices, &weights, &basis_functions, x)?;
239
240        Ok(result)
241    }
242
243    /// Evaluate the MLS approximation at multiple points
244    ///
245    /// # Arguments
246    ///
247    /// * `points` - Query points with shape (n_points, n_dims)
248    ///
249    /// # Returns
250    ///
251    /// Interpolated values at the query points
252    pub fn evaluate_multi(&self, points: &ArrayView2<F>) -> InterpolateResult<Array1<F>> {
253        // Check dimensions
254        if points.shape()[1] != self.points.shape()[1] {
255            return Err(InterpolateError::DimensionMismatch(
256                "Query points dimension must match training points".to_string(),
257            ));
258        }
259
260        let n_points = points.shape()[0];
261        let mut results = Array1::zeros(n_points);
262
263        // Evaluate at each point
264        for i in 0..n_points {
265            let point = points.slice(scirs2_core::ndarray::s![i, ..]);
266            results[i] = self.evaluate(&point)?;
267        }
268
269        Ok(results)
270    }
271
272    /// Find points to use for local fit
273    ///
274    /// Returns indices of points to use and their distances to the query point
275    fn find_relevant_points(&self, x: &ArrayView1<F>) -> InterpolateResult<(Vec<usize>, Vec<F>)> {
276        let n_points = self.points.shape()[0];
277        let n_dims = self.points.shape()[1];
278
279        // Compute squared distances
280        let mut distances = Vec::with_capacity(n_points);
281        for i in 0..n_points {
282            let mut d_squared = F::zero();
283            for j in 0..n_dims {
284                let diff = x[j] - self.points[[i, j]];
285                d_squared = d_squared + diff * diff;
286            }
287            let dist = d_squared.sqrt();
288            distances.push((i, dist));
289        }
290
291        // Sort by distance
292        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
293
294        // Apply max_points limit if specified
295        let limit = match self.max_points {
296            Some(limit) => std::cmp::min(limit, n_points),
297            None => n_points,
298        };
299
300        // Filter out points with zero weight (if using compactly supported weight function)
301        let effective_radius = match self.weight_fn {
302            WeightFunction::WendlandC2 | WeightFunction::CubicSpline => self.bandwidth,
303            _ => F::infinity(),
304        };
305
306        let mut indices = Vec::new();
307        let mut dist_values = Vec::new();
308
309        for &(idx, dist) in distances.iter().take(limit) {
310            if dist <= effective_radius {
311                indices.push(idx);
312                dist_values.push(dist);
313            }
314        }
315
316        // Ensure we have enough points for the basis
317        let min_points = match self.basis {
318            PolynomialBasis::Constant => 1,
319            PolynomialBasis::Linear => n_dims + 1,
320            PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
321        };
322
323        if indices.len() < min_points {
324            // If not enough points with compact support, take the closest ones
325            indices = distances
326                .iter()
327                .take(min_points)
328                .map(|&(idx, _)| idx)
329                .collect();
330            dist_values = distances
331                .iter()
332                .take(min_points)
333                .map(|&(_, dist)| dist)
334                .collect();
335        }
336
337        Ok((indices, dist_values))
338    }
339
340    /// Compute weights for the given distances
341    fn compute_weights(&self, distances: &[F]) -> InterpolateResult<Array1<F>> {
342        let n = distances.len();
343        let mut weights = Array1::zeros(n);
344
345        for (i, &d) in distances.iter().enumerate() {
346            // Normalize distance by bandwidth
347            let r = d / self.bandwidth;
348
349            // Compute weight based on the chosen weight function
350            let weight = match self.weight_fn {
351                WeightFunction::Gaussian => (-r * r).exp(),
352                WeightFunction::WendlandC2 => {
353                    if r < F::one() {
354                        let t = F::one() - r;
355                        let factor = F::from_f64(4.0).unwrap() * r + F::one();
356                        t.powi(4) * factor
357                    } else {
358                        F::zero()
359                    }
360                }
361                WeightFunction::InverseDistance => F::one() / (self.epsilon + r * r),
362                WeightFunction::CubicSpline => {
363                    if r < F::from_f64(1.0 / 3.0).unwrap() {
364                        let r2 = r * r;
365                        let r3 = r2 * r;
366                        F::from_f64(2.0 / 3.0).unwrap() - F::from_f64(9.0).unwrap() * r2
367                            + F::from_f64(19.0).unwrap() * r3
368                    } else if r < F::one() {
369                        let t = F::from_f64(2.0).unwrap() - F::from_f64(3.0).unwrap() * r;
370                        F::from_f64(1.0 / 3.0).unwrap() * t.powi(3)
371                    } else {
372                        F::zero()
373                    }
374                }
375            };
376
377            weights[i] = weight;
378        }
379
380        // Normalize weights to sum to 1 for numerical stability
381        let sum = weights.sum();
382        if sum > F::zero() {
383            weights.mapv_inplace(|w| w / sum);
384        } else {
385            // If all weights are zero (shouldn't happen), use equal weights
386            weights.fill(F::from_f64(1.0 / n as f64).unwrap());
387        }
388
389        Ok(weights)
390    }
391
392    /// Create basis functions for the given points
393    fn create_basis_functions(
394        &self,
395        indices: &[usize],
396        x: &ArrayView1<F>,
397    ) -> InterpolateResult<Array2<F>> {
398        let n_points = indices.len();
399        let n_dims = x.len();
400
401        // Determine number of basis functions
402        let n_basis = match self.basis {
403            PolynomialBasis::Constant => 1,
404            PolynomialBasis::Linear => n_dims + 1,
405            PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
406        };
407
408        let mut basis = Array2::zeros((n_points, n_basis));
409
410        // Fill basis functions for each point
411        for (i, &idx) in indices.iter().enumerate() {
412            let point = self.points.row(idx);
413            let mut col = 0;
414
415            // Constant term
416            basis[[i, col]] = F::one();
417            col += 1;
418
419            if self.basis == PolynomialBasis::Linear || self.basis == PolynomialBasis::Quadratic {
420                // Linear terms
421                for j in 0..n_dims {
422                    basis[[i, col]] = point[j];
423                    col += 1;
424                }
425            }
426
427            if self.basis == PolynomialBasis::Quadratic {
428                // Quadratic terms
429                for j in 0..n_dims {
430                    for k in j..n_dims {
431                        basis[[i, col]] = point[j] * point[k];
432                        col += 1;
433                    }
434                }
435            }
436        }
437
438        Ok(basis)
439    }
440
441    /// Create basis functions for evaluation at the query point
442    fn create_query_basis(&self, x: &ArrayView1<F>) -> InterpolateResult<Array1<F>> {
443        let n_dims = x.len();
444
445        // Determine number of basis functions
446        let n_basis = match self.basis {
447            PolynomialBasis::Constant => 1,
448            PolynomialBasis::Linear => n_dims + 1,
449            PolynomialBasis::Quadratic => ((n_dims + 1) * (n_dims + 2)) / 2,
450        };
451
452        let mut basis = Array1::zeros(n_basis);
453        let mut col = 0;
454
455        // Constant term
456        basis[col] = F::one();
457        col += 1;
458
459        if self.basis == PolynomialBasis::Linear || self.basis == PolynomialBasis::Quadratic {
460            // Linear terms
461            for j in 0..n_dims {
462                basis[col] = x[j];
463                col += 1;
464            }
465        }
466
467        if self.basis == PolynomialBasis::Quadratic {
468            // Quadratic terms
469            for j in 0..n_dims {
470                for k in j..n_dims {
471                    basis[col] = x[j] * x[k];
472                    col += 1;
473                }
474            }
475        }
476
477        Ok(basis)
478    }
479
480    /// Solve the weighted least squares problem
481    fn solve_weighted_least_squares(
482        &self,
483        indices: &[usize],
484        weights: &Array1<F>,
485        basis: &Array2<F>,
486        x: &ArrayView1<F>,
487    ) -> InterpolateResult<F> {
488        let n_points = indices.len();
489        let n_basis = basis.shape()[1];
490
491        // Create the weighted basis matrix and target vector
492        let mut w_basis = Array2::zeros((n_points, n_basis));
493        let mut w_values = Array1::zeros(n_points);
494
495        for i in 0..n_points {
496            let sqrt_w = weights[i].sqrt();
497            for j in 0..n_basis {
498                w_basis[[i, j]] = basis[[i, j]] * sqrt_w;
499            }
500            w_values[i] = self.values[indices[i]] * sqrt_w;
501        }
502
503        // Solve the least squares problem: (B'B)c = B'y
504        #[cfg(feature = "linalg")]
505        let btb = w_basis.t().dot(&w_basis);
506        #[cfg(not(feature = "linalg"))]
507        let _btb = w_basis.t().dot(&w_basis);
508        #[allow(unused_variables)]
509        let bty = w_basis.t().dot(&w_values);
510
511        // Solve the system for coefficients
512        #[cfg(feature = "linalg")]
513        let coeffs = {
514            use scirs2_linalg::solve;
515            let btb_f64 = btb.mapv(|x| x.to_f64().unwrap());
516            let bty_f64 = bty.mapv(|x| x.to_f64().unwrap());
517            match solve(&btb_f64.view(), &bty_f64.view(), None) {
518                Ok(c) => c.mapv(|x| F::from_f64(x).unwrap()),
519                Err(_) => {
520                    // Fallback: use local mean for numerical stability
521                    let mut mean = F::zero();
522                    let mut sum_weights = F::zero();
523                    for (i, &idx) in indices.iter().enumerate() {
524                        mean = mean + weights[i] * self.values[idx];
525                        sum_weights = sum_weights + weights[i];
526                    }
527
528                    if sum_weights > F::zero() {
529                        // For the fallback, we'll create a coefficient vector with just the mean
530                        // as the constant term and zeros elsewhere
531                        let mut fallback_coeffs = Array1::zeros(bty.len());
532                        fallback_coeffs[0] = mean / sum_weights;
533                        fallback_coeffs
534                    } else {
535                        return Err(InterpolateError::ComputationError(
536                            "Failed to solve weighted least squares system".to_string(),
537                        ));
538                    }
539                }
540            }
541        };
542
543        #[cfg(not(feature = "linalg"))]
544        let coeffs = {
545            // Fallback implementation when linalg is not available
546            // Simple diagonal approximation
547            let mut result = Array1::zeros(bty.len());
548
549            // Use local mean for constant term
550            let mut mean = F::zero();
551            let mut sum_weights = F::zero();
552            for (i, &idx) in indices.iter().enumerate() {
553                mean = mean + weights[i] * self.values[idx];
554                sum_weights = sum_weights + weights[i];
555            }
556
557            if sum_weights > F::zero() {
558                result[0] = mean / sum_weights;
559            }
560
561            result
562        };
563
564        // Evaluate at the query point by creating the basis for it
565        let query_basis = self.create_query_basis(x)?;
566        let result = query_basis.dot(&coeffs);
567
568        Ok(result)
569    }
570
    /// Get the weight function used by this MLS interpolator
    ///
    /// This is the value passed to `new`; it is returned by copy.
    pub fn weight_fn(&self) -> WeightFunction {
        self.weight_fn
    }
575
    /// Get the bandwidth parameter (h) used by this MLS interpolator
    ///
    /// This is the value passed to `new`; it is returned by copy.
    pub fn bandwidth(&self) -> F {
        self.bandwidth
    }
580
    /// Get the points used by this MLS interpolator
    ///
    /// Returns a borrow of the stored coordinate array, one row per point.
    pub fn points(&self) -> &Array2<F> {
        &self.points
    }
585
    /// Get the values used by this MLS interpolator
    ///
    /// Returns a borrow of the stored values; entry `i` belongs to row `i` of `points()`.
    pub fn values(&self) -> &Array1<F> {
        &self.values
    }
590
    /// Get the basis type used by this MLS interpolator
    ///
    /// This is the value passed to `new`; it is returned by copy.
    pub fn basis(&self) -> PolynomialBasis {
        self.basis
    }
595
    /// Get the maximum points setting used by this MLS interpolator
    ///
    /// `None` until `with_max_points` has been called.
    pub fn max_points(&self) -> Option<usize> {
        self.max_points
    }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Corners of the unit square with samples of the plane z = x + y.
    fn unit_square_plane() -> (Array2<f64>, Array1<f64>) {
        let corners = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let heights = array![0.0, 1.0, 1.0, 2.0];
        (corners, heights)
    }

    #[test]
    fn test_mls_constant_basis() {
        let (points, values) = unit_square_plane();

        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Constant,
            0.5,
        )
        .unwrap();

        // The centre is equidistant from all four corners, so the constant fit should
        // land near the average of the values (1.0).
        let centre = array![0.5, 0.5];
        let val = mls.evaluate(&centre.view()).unwrap();

        assert_abs_diff_eq!(val, 1.0, epsilon = 0.1);
    }

    #[test]
    fn test_mls_linear_basis() {
        let (points, values) = unit_square_plane();

        let mls = MovingLeastSquares::new(
            points,
            values,
            WeightFunction::Gaussian,
            PolynomialBasis::Linear,
            1.0,
        )
        .unwrap();

        // A linear basis can represent the plane z = x + y, so each query is paired
        // with the exact plane value.
        let queries = array![
            [0.5, 0.5],
            [0.25, 0.25],
            [0.75, 0.25],
            [0.25, 0.75],
            [0.75, 0.75],
        ];
        let expected = array![1.0, 0.5, 1.0, 1.0, 1.5];

        let results = mls.evaluate_multi(&queries.view()).unwrap();

        // Generous tolerance: some numerical error is allowed.
        for (got, want) in results.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 0.5);
        }
    }

    #[test]
    fn test_different_weight_functions() {
        // Unit-square corners plus two interior points, well spaced to avoid
        // singularities; the values sample the linear function z = x + y.
        let points = array![
            [0.0, 0.0],
            [1.0, 0.0],
            [0.0, 1.0],
            [1.0, 1.0],
            [0.3, 0.3],
            [0.7, 0.7],
        ];
        let values = array![0.0, 1.0, 1.0, 2.0, 0.6, 1.4];

        let query = array![0.5, 0.5];
        let expected = 0.5 + 0.5; // z = x + y at the query point

        for weight_fn in [WeightFunction::Gaussian, WeightFunction::InverseDistance] {
            let mls = MovingLeastSquares::new(
                points.clone(),
                values.clone(),
                weight_fn,
                PolynomialBasis::Linear, // linear basis for better stability
                2.0,                     // large bandwidth to include all points
            )
            .unwrap();

            let val = match mls.evaluate(&query.view()) {
                Ok(val) => val,
                Err(e) => panic!("Weight function {:?} failed with error: {}", weight_fn, e),
            };

            assert!(
                val.is_finite(),
                "Weight function {:?} produced non-finite result: {}",
                weight_fn,
                val
            );

            // Allow reasonable error for MLS approximation
            assert!(
                (val - expected).abs() < 0.5,
                "Weight function {:?}: result {:.6} differs too much from expected {:.6}",
                weight_fn,
                val,
                expected
            );
        }
    }
}