scirs2_spatial/
kriging.rs

1//! Kriging interpolation methods
2//!
3//! This module provides implementations of Kriging, also known as Gaussian process regression,
4//! which is a method of spatial interpolation based on the theory of regionalized variables.
5//! Kriging provides the best linear unbiased estimator (BLUE) for spatial data.
6//!
7//! # Theory
8//!
9//! Kriging assumes that the data follows a spatial stochastic process and uses
//! a variogram or covariance function to model spatial correlation. The main
//! variants of Kriging are:
//!
//! - **Simple Kriging**: Assumes a known constant mean (see [`SimpleKriging`])
//! - **Ordinary Kriging**: Estimates the mean locally; the most common variant (see [`OrdinaryKriging`])
//! - **Universal Kriging**: Models the trend with basis functions (not implemented in this module)
16//!
17//! The Kriging prediction at location x₀ is:
18//! Z*(x₀) = Σᵢ λᵢ Z(xᵢ)
19//!
20//! where λᵢ are weights determined by solving the Kriging system.
21//!
22//! # Examples
23//!
24//! ```
25//! use scirs2_spatial::kriging::{OrdinaryKriging, VariogramModel};
26//! use scirs2_core::ndarray::array;
27//!
//! // Sample data points (x, y); the observed values below are z
29//! let points = array![
30//!     [0.0, 0.0],
31//!     [1.0, 0.0],
32//!     [0.0, 1.0],
33//!     [1.0, 1.0],
34//!     [0.5, 0.5]
35//! ];
36//!
37//! let values = array![1.0, 2.0, 3.0, 4.0, 2.5];
38//!
39//! // Create Kriging interpolator with spherical variogram
40//! let variogram = VariogramModel::spherical(1.0, 0.1, 0.0);
41//! let kriging = OrdinaryKriging::new(&points.view(), &values.view(), variogram).unwrap();
42//!
43//! // Interpolate at new location
44//! let prediction = kriging.predict(&[0.25, 0.25]).unwrap();
45//! println!("Predicted value: {:.3}", prediction.value);
46//! println!("Prediction variance: {:.3}", prediction.variance);
47//! ```
48
49use crate::error::{SpatialError, SpatialResult};
50use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
51
/// Variogram model types for Kriging
///
/// For the bounded models, `range` is the scale parameter `a`, `sill` is the
/// partial sill `c₁` and `nugget` is `c₀`. The total sill reported by
/// [`VariogramModel::sill`] is therefore `c₀ + c₁`. By this module's
/// convention, [`VariogramModel::evaluate`] returns the nugget at zero distance.
#[derive(Debug, Clone, PartialEq)]
pub enum VariogramModel {
    /// Spherical variogram: γ(h) = c₀ + c₁[1.5(h/a) - 0.5(h/a)³] for h < a, γ(h) = c₀ + c₁ for h ≥ a
    Spherical { range: f64, sill: f64, nugget: f64 },
    /// Exponential variogram: γ(h) = c₀ + c₁[1 - exp(-h/a)]
    Exponential { range: f64, sill: f64, nugget: f64 },
    /// Gaussian variogram: γ(h) = c₀ + c₁[1 - exp(-(h/a)²)]
    Gaussian { range: f64, sill: f64, nugget: f64 },
    /// Linear variogram: γ(h) = c₀ + c₁h (unbounded)
    Linear { slope: f64, nugget: f64 },
    /// Power variogram: γ(h) = c₀ + c₁h^α for 0 < α < 2 (unbounded)
    Power {
        coefficient: f64,
        exponent: f64,
        nugget: f64,
    },
    /// Matérn variogram with smoothness parameter ν: γ(h) = c₀ + c₁[1 - ρ_ν(h/a)]
    ///
    /// Exact Matérn correlations ρ_ν are used for ν = 0.5, 1.5 and 2.5.
    /// Any other ν falls back to the approximation ρ(r) = (1 + r)·exp(-r).
    Matern {
        range: f64,
        sill: f64,
        nugget: f64,
        nu: f64,
    },
}

impl VariogramModel {
    /// Create a spherical variogram model
    ///
    /// # Arguments
    /// * `range` - Range parameter (distance at which the sill is reached)
    /// * `sill` - Partial sill `c₁` (added on top of the nugget)
    /// * `nugget` - Nugget parameter (value returned at zero distance)
    pub fn spherical(range: f64, sill: f64, nugget: f64) -> Self {
        Self::Spherical {
            range,
            sill,
            nugget,
        }
    }

    /// Create an exponential variogram model (`range` is the scale parameter `a`)
    pub fn exponential(range: f64, sill: f64, nugget: f64) -> Self {
        Self::Exponential {
            range,
            sill,
            nugget,
        }
    }

    /// Create a Gaussian variogram model (`range` is the scale parameter `a`)
    pub fn gaussian(range: f64, sill: f64, nugget: f64) -> Self {
        Self::Gaussian {
            range,
            sill,
            nugget,
        }
    }

    /// Create an unbounded linear variogram model `γ(h) = nugget + slope·h`
    pub fn linear(slope: f64, nugget: f64) -> Self {
        Self::Linear { slope, nugget }
    }

    /// Create an unbounded power variogram model `γ(h) = nugget + coefficient·h^exponent`
    pub fn power(coefficient: f64, exponent: f64, nugget: f64) -> Self {
        Self::Power {
            coefficient,
            exponent,
            nugget,
        }
    }

    /// Create a Matérn variogram model with smoothness parameter `nu`
    pub fn matern(range: f64, sill: f64, nugget: f64, nu: f64) -> Self {
        Self::Matern {
            range,
            sill,
            nugget,
            nu,
        }
    }

    /// Evaluate the variogram at distance h
    ///
    /// # Arguments
    /// * `h` - Distance
    ///
    /// # Returns
    /// * Variogram value. A negative `h` yields `0.0`, and any `h` below `1e-15`
    ///   yields the nugget.
    pub fn evaluate(&self, h: f64) -> f64 {
        if h < 0.0 {
            return 0.0;
        }

        if h.abs() < 1e-15 {
            return self.nugget();
        }

        match *self {
            Self::Spherical {
                range,
                sill,
                nugget,
            } => {
                if h >= range {
                    nugget + sill
                } else {
                    let h_r = h / range;
                    nugget + sill * (1.5 * h_r - 0.5 * h_r.powi(3))
                }
            }
            Self::Exponential {
                range,
                sill,
                nugget,
            } => nugget + sill * (1.0 - (-h / range).exp()),
            Self::Gaussian {
                range,
                sill,
                nugget,
            } => nugget + sill * (1.0 - (-(h / range).powi(2)).exp()),
            Self::Linear { slope, nugget } => nugget + slope * h,
            Self::Power {
                coefficient,
                exponent,
                nugget,
            } => nugget + coefficient * h.powf(exponent),
            Self::Matern {
                range,
                sill,
                nugget,
                nu,
            } => {
                let h_r = h / range;
                if h_r < 1e-10 {
                    nugget
                } else {
                    // ρ decays from 1 towards 0, so γ rises from the nugget
                    // towards the total sill as the distance grows.
                    nugget + sill * (1.0 - Self::matern_correlation(h_r, nu))
                }
            }
        }
    }

    /// Matérn correlation ρ_ν(r) at the scaled distance `r = h / range` (`r > 0`).
    ///
    /// Always returns a correlation (1 at r → 0, decaying to 0), never a
    /// variogram-like `1 - ρ`, for every ν.
    fn matern_correlation(r: f64, nu: f64) -> f64 {
        const NU_TOL: f64 = 1e-10;
        if (nu - 0.5).abs() < NU_TOL {
            // ν = 0.5: the exponential correlation
            (-r).exp()
        } else if (nu - 1.5).abs() < NU_TOL {
            (1.0 + 3.0_f64.sqrt() * r) * (-3.0_f64.sqrt() * r).exp()
        } else if (nu - 2.5).abs() < NU_TOL {
            (1.0 + 5.0_f64.sqrt() * r + 5.0 * r.powi(2) / 3.0) * (-5.0_f64.sqrt() * r).exp()
        } else {
            // General ν: a simple approximation, not the exact Bessel-function form.
            (-r).exp() * (1.0 + r)
        }
    }

    /// Get the effective (practical) range of the variogram
    ///
    /// Spherical reaches its sill exactly at `range`. Exponential (`3a`) and
    /// Gaussian (`√3·a`) use the distance where the model comes within about 5%
    /// of the sill. Matérn uses `3a` as a rough approximation. Linear and Power
    /// never level off, so they return `f64::INFINITY`.
    pub fn effective_range(&self) -> f64 {
        match *self {
            Self::Spherical { range, .. } => range,
            Self::Exponential { range, .. } => 3.0 * range,
            Self::Gaussian { range, .. } => 3.0_f64.sqrt() * range,
            Self::Linear { .. } | Self::Power { .. } => f64::INFINITY,
            Self::Matern { range, .. } => 3.0 * range,
        }
    }

    /// Get the total sill (`sill + nugget`) of the variogram
    ///
    /// Returns `f64::INFINITY` for the unbounded Linear and Power models.
    pub fn sill(&self) -> f64 {
        match *self {
            Self::Spherical { sill, nugget, .. }
            | Self::Exponential { sill, nugget, .. }
            | Self::Gaussian { sill, nugget, .. }
            | Self::Matern { sill, nugget, .. } => sill + nugget,
            Self::Linear { .. } | Self::Power { .. } => f64::INFINITY,
        }
    }

    /// Get the nugget effect (the value returned by `evaluate` at zero distance)
    pub fn nugget(&self) -> f64 {
        match *self {
            Self::Spherical { nugget, .. }
            | Self::Exponential { nugget, .. }
            | Self::Gaussian { nugget, .. }
            | Self::Linear { nugget, .. }
            | Self::Power { nugget, .. }
            | Self::Matern { nugget, .. } => nugget,
        }
    }
}
253
254/// Prediction result from Kriging interpolation
255#[derive(Debug, Clone)]
256pub struct KrigingPrediction {
257    /// Predicted value
258    pub value: f64,
259    /// Prediction variance (uncertainty)
260    pub variance: f64,
261    /// Weights used in the prediction
262    pub weights: Array1<f64>,
263}
264
265/// Ordinary Kriging interpolator
266///
267/// Ordinary Kriging assumes the mean is unknown but constant within a local neighborhood.
268/// It provides the Best Linear Unbiased Estimator (BLUE) for spatial data.
269#[derive(Debug, Clone)]
270pub struct OrdinaryKriging {
271    /// Data point locations
272    points: Array2<f64>,
273    /// Data values at points
274    values: Array1<f64>,
275    /// Variogram model
276    variogram: VariogramModel,
277    /// Number of data points
278    n_points: usize,
279    /// Dimension of space
280    ndim: usize,
281    /// Precomputed covariance matrix (inverse)
282    cov_matrix_inv: Option<Array2<f64>>,
283}
284
285impl OrdinaryKriging {
286    /// Create a new Ordinary Kriging interpolator
287    ///
288    /// # Arguments
289    /// * `points` - Array of point coordinates, shape (n_points, ndim)
290    /// * `values` - Array of values at points, shape (n_points,)
291    /// * `variogram` - Variogram model to use
292    ///
293    /// # Returns
294    /// * New OrdinaryKriging instance
295    ///
296    /// # Examples
297    ///
298    /// ```
299    /// use scirs2_spatial::kriging::{OrdinaryKriging, VariogramModel};
300    /// use scirs2_core::ndarray::array;
301    ///
302    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
303    /// let values = array![1.0, 2.0, 3.0];
304    /// let variogram = VariogramModel::spherical(1.0, 0.5, 0.1);
305    ///
306    /// let kriging = OrdinaryKriging::new(&points.view(), &values.view(), variogram).unwrap();
307    /// ```
308    pub fn new(
309        points: &ArrayView2<'_, f64>,
310        values: &ArrayView1<f64>,
311        variogram: VariogramModel,
312    ) -> SpatialResult<Self> {
313        let n_points = points.nrows();
314        let ndim = points.ncols();
315
316        if values.len() != n_points {
317            return Err(SpatialError::ValueError(
318                "Number of values must match number of points".to_string(),
319            ));
320        }
321
322        if n_points < 3 {
323            return Err(SpatialError::ValueError(
324                "Need at least 3 points for Kriging".to_string(),
325            ));
326        }
327
328        if !(1..=3).contains(&ndim) {
329            return Err(SpatialError::ValueError(
330                "Kriging supports 1D, 2D, and 3D points only".to_string(),
331            ));
332        }
333
334        Ok(Self {
335            points: points.to_owned(),
336            values: values.to_owned(),
337            variogram,
338            n_points,
339            ndim,
340            cov_matrix_inv: None,
341        })
342    }
343
344    /// Fit the Kriging model by precomputing the covariance matrix inverse
345    ///
346    /// This step is optional but recommended for multiple predictions
347    /// as it avoids recomputing the matrix inverse each time.
348    pub fn fit(&mut self) -> SpatialResult<()> {
349        let cov_matrix = self.build_covariance_matrix()?;
350        let inv_matrix = OrdinaryKriging::invert_matrix(&cov_matrix)?;
351        self.cov_matrix_inv = Some(inv_matrix);
352        Ok(())
353    }
354
355    /// Predict value at a new location
356    ///
357    /// # Arguments
358    /// * `location` - Point where to predict, shape (ndim,)
359    ///
360    /// # Returns
361    /// * KrigingPrediction with value, variance, and weights
362    ///
363    /// # Examples
364    ///
365    /// ```
366    /// use scirs2_spatial::kriging::{OrdinaryKriging, VariogramModel};
367    /// use scirs2_core::ndarray::array;
368    ///
369    /// let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
370    /// let values = array![1.0, 2.0, 3.0, 4.0];
371    /// let variogram = VariogramModel::spherical(1.5, 1.0, 0.1);
372    ///
373    /// let mut kriging = OrdinaryKriging::new(&points.view(), &values.view(), variogram).unwrap();
374    /// kriging.fit().unwrap();
375    ///
376    /// let prediction = kriging.predict(&[0.5, 0.5]).unwrap();
377    /// println!("Predicted: {:.3} ± {:.3}", prediction.value, prediction.variance.sqrt());
378    /// ```
379    pub fn predict(&self, location: &[f64]) -> SpatialResult<KrigingPrediction> {
380        if location.len() != self.ndim {
381            return Err(SpatialError::ValueError(
382                "Location dimension must match data dimension".to_string(),
383            ));
384        }
385
386        // Build covariance matrix if not precomputed
387        let cov_inv = if let Some(ref inv) = self.cov_matrix_inv {
388            inv.clone()
389        } else {
390            let cov_matrix = self.build_covariance_matrix()?;
391            OrdinaryKriging::invert_matrix(&cov_matrix)?
392        };
393
394        // Build covariance vector between new _location and data points
395        let mut cov_vector = Array1::zeros(self.n_points + 1);
396        for i in 0..self.n_points {
397            let dist = OrdinaryKriging::distance(location, &self.points.row(i).to_vec());
398            cov_vector[i] = self.variogram.sill() - self.variogram.evaluate(dist);
399        }
400        cov_vector[self.n_points] = 1.0; // Lagrange multiplier for unbiasedness constraint
401
402        // Solve for weights
403        let weights_extended = cov_inv.dot(&cov_vector);
404        let weights = weights_extended.slice(s![..self.n_points]).to_owned();
405
406        // Calculate prediction
407        let value = weights.dot(&self.values);
408
409        // Calculate prediction variance
410        let variance = self.variogram.sill() - weights_extended.dot(&cov_vector);
411
412        Ok(KrigingPrediction {
413            value,
414            variance: variance.max(0.0), // Ensure non-negative variance
415            weights,
416        })
417    }
418
419    /// Predict values at multiple locations efficiently
420    ///
421    /// # Arguments
422    /// * `locations` - Array of locations, shape (n_locations, ndim)
423    ///
424    /// # Returns
425    /// * Vector of KrigingPrediction results
426    pub fn predict_batch(
427        &self,
428        locations: &ArrayView2<'_, f64>,
429    ) -> SpatialResult<Vec<KrigingPrediction>> {
430        if locations.ncols() != self.ndim {
431            return Err(SpatialError::ValueError(
432                "Location dimension must match data dimension".to_string(),
433            ));
434        }
435
436        // Precompute covariance matrix inverse if not done
437        let cov_inv = if let Some(ref inv) = self.cov_matrix_inv {
438            inv.clone()
439        } else {
440            let cov_matrix = self.build_covariance_matrix()?;
441            OrdinaryKriging::invert_matrix(&cov_matrix)?
442        };
443
444        let mut predictions = Vec::with_capacity(locations.nrows());
445
446        for location_row in locations.outer_iter() {
447            let location: Vec<f64> = location_row.to_vec();
448
449            // Build covariance vector
450            let mut cov_vector = Array1::zeros(self.n_points + 1);
451            for i in 0..self.n_points {
452                let dist = OrdinaryKriging::distance(&location, &self.points.row(i).to_vec());
453                cov_vector[i] = self.variogram.sill() - self.variogram.evaluate(dist);
454            }
455            cov_vector[self.n_points] = 1.0;
456
457            // Solve for weights
458            let weights_extended = cov_inv.dot(&cov_vector);
459            let weights = weights_extended.slice(s![..self.n_points]).to_owned();
460
461            // Calculate prediction and variance
462            let value = weights.dot(&self.values);
463            let variance = (self.variogram.sill() - weights_extended.dot(&cov_vector)).max(0.0);
464
465            predictions.push(KrigingPrediction {
466                value,
467                variance,
468                weights,
469            });
470        }
471
472        Ok(predictions)
473    }
474
475    /// Build the covariance matrix for the Kriging system
476    fn build_covariance_matrix(&self) -> SpatialResult<Array2<f64>> {
477        let size = self.n_points + 1;
478        let mut matrix = Array2::zeros((size, size));
479
480        // Fill covariance values
481        for i in 0..self.n_points {
482            for j in 0..self.n_points {
483                let dist = if i == j {
484                    0.0
485                } else {
486                    OrdinaryKriging::distance(
487                        &self.points.row(i).to_vec(),
488                        &self.points.row(j).to_vec(),
489                    )
490                };
491                // Covariance = Sill - Variogram
492                matrix[[i, j]] = self.variogram.sill() - self.variogram.evaluate(dist);
493            }
494        }
495
496        // Unbiasedness constraint (Lagrange multipliers)
497        for i in 0..self.n_points {
498            matrix[[i, self.n_points]] = 1.0;
499            matrix[[self.n_points, i]] = 1.0;
500        }
501        matrix[[self.n_points, self.n_points]] = 0.0;
502
503        Ok(matrix)
504    }
505
506    /// Compute Euclidean distance between two points
507    #[allow(dead_code)]
508    fn distance(p1: &[f64], p2: &[f64]) -> f64 {
509        p1.iter()
510            .zip(p2.iter())
511            .map(|(a, b)| (a - b).powi(2))
512            .sum::<f64>()
513            .sqrt()
514    }
515
516    /// Invert a matrix using Gaussian elimination with partial pivoting
517    fn invert_matrix(matrix: &Array2<f64>) -> SpatialResult<Array2<f64>> {
518        let n = matrix.nrows();
519        if n != matrix.ncols() {
520            return Err(SpatialError::ComputationError(
521                "Matrix must be square for inversion".to_string(),
522            ));
523        }
524
525        // Create augmented _matrix [A | I]
526        let mut aug = Array2::zeros((n, 2 * n));
527
528        // Fill A part
529        for i in 0..n {
530            for j in 0..n {
531                aug[[i, j]] = matrix[[i, j]];
532            }
533        }
534
535        // Fill identity part
536        for i in 0..n {
537            aug[[i, n + i]] = 1.0;
538        }
539
540        // Gaussian elimination with partial pivoting
541        for i in 0..n {
542            // Find pivot
543            let mut max_row = i;
544            for k in (i + 1)..n {
545                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
546                    max_row = k;
547                }
548            }
549
550            // Swap rows
551            if max_row != i {
552                for j in 0..(2 * n) {
553                    let temp = aug[[i, j]];
554                    aug[[i, j]] = aug[[max_row, j]];
555                    aug[[max_row, j]] = temp;
556                }
557            }
558
559            // Check for singular _matrix
560            if aug[[i, i]].abs() < 1e-12 {
561                return Err(SpatialError::ComputationError(
562                    "Matrix is singular (not invertible)".to_string(),
563                ));
564            }
565
566            // Scale pivot row
567            let pivot = aug[[i, i]];
568            for j in 0..(2 * n) {
569                aug[[i, j]] /= pivot;
570            }
571
572            // Eliminate column
573            for k in 0..n {
574                if k != i {
575                    let factor = aug[[k, i]];
576                    for j in 0..(2 * n) {
577                        aug[[k, j]] -= factor * aug[[i, j]];
578                    }
579                }
580            }
581        }
582
583        // Extract inverse _matrix
584        let mut inverse = Array2::zeros((n, n));
585        for i in 0..n {
586            for j in 0..n {
587                inverse[[i, j]] = aug[[i, n + j]];
588            }
589        }
590
591        Ok(inverse)
592    }
593
594    /// Get the variogram model
595    pub fn variogram(&self) -> &VariogramModel {
596        &self.variogram
597    }
598
599    /// Get the number of data points
600    pub fn n_points(&self) -> usize {
601        self.n_points
602    }
603
604    /// Get the data points
605    pub fn points(&self) -> &Array2<f64> {
606        &self.points
607    }
608
609    /// Get the data values
610    pub fn values(&self) -> &Array1<f64> {
611        &self.values
612    }
613
614    /// Cross-validation: leave-one-out prediction errors
615    ///
616    /// # Returns
617    /// * Array of prediction errors (predicted - actual)
618    pub fn cross_validate(&self) -> SpatialResult<Array1<f64>> {
619        let mut errors = Array1::zeros(self.n_points);
620
621        for i in 0..self.n_points {
622            // Create subset without point i
623            let mut subset_points = Array2::zeros((self.n_points - 1, self.ndim));
624            let mut subset_values = Array1::zeros(self.n_points - 1);
625
626            let mut idx = 0;
627            for j in 0..self.n_points {
628                if j != i {
629                    subset_points.row_mut(idx).assign(&self.points.row(j));
630                    subset_values[idx] = self.values[j];
631                    idx += 1;
632                }
633            }
634
635            // Create Kriging model without point i
636            let subset_kriging = OrdinaryKriging::new(
637                &subset_points.view(),
638                &subset_values.view(),
639                self.variogram.clone(),
640            )?;
641
642            // Predict at point i
643            let location: Vec<f64> = self.points.row(i).to_vec();
644            let prediction = subset_kriging.predict(&location)?;
645
646            errors[i] = prediction.value - self.values[i];
647        }
648
649        Ok(errors)
650    }
651}
652
653/// Simple Kriging interpolator
654///
655/// Simple Kriging assumes a known constant mean value.
656#[derive(Debug, Clone)]
657pub struct SimpleKriging {
658    /// Data point locations
659    points: Array2<f64>,
660    /// Data values at points
661    values: Array1<f64>,
662    /// Known mean value
663    mean: f64,
664    /// Variogram model
665    variogram: VariogramModel,
666    /// Number of data points
667    n_points: usize,
668    /// Dimension of space
669    ndim: usize,
670}
671
672impl SimpleKriging {
673    /// Create a new Simple Kriging interpolator
674    ///
675    /// # Arguments
676    /// * `points` - Array of point coordinates
677    /// * `values` - Array of values at points
678    /// * `mean` - Known mean value
679    /// * `variogram` - Variogram model
680    pub fn new(
681        points: &ArrayView2<'_, f64>,
682        values: &ArrayView1<f64>,
683        mean: f64,
684        variogram: VariogramModel,
685    ) -> SpatialResult<Self> {
686        let n_points = points.nrows();
687        let ndim = points.ncols();
688
689        if values.len() != n_points {
690            return Err(SpatialError::ValueError(
691                "Number of values must match number of points".to_string(),
692            ));
693        }
694
695        if n_points < 2 {
696            return Err(SpatialError::ValueError(
697                "Need at least 2 points for Simple Kriging".to_string(),
698            ));
699        }
700
701        Ok(Self {
702            points: points.to_owned(),
703            values: values.to_owned(),
704            mean,
705            variogram,
706            n_points,
707            ndim,
708        })
709    }
710
711    /// Predict value at a new location
712    ///
713    /// # Arguments
714    /// * `location` - Point where to predict
715    ///
716    /// # Returns
717    /// * KrigingPrediction with value, variance, and weights
718    pub fn predict(&self, location: &[f64]) -> SpatialResult<KrigingPrediction> {
719        if location.len() != self.ndim {
720            return Err(SpatialError::ValueError(
721                "Location dimension must match data dimension".to_string(),
722            ));
723        }
724
725        // Build covariance matrix (without Lagrange multiplier)
726        let mut cov_matrix = Array2::zeros((self.n_points, self.n_points));
727        for i in 0..self.n_points {
728            for j in 0..self.n_points {
729                let dist = if i == j {
730                    0.0
731                } else {
732                    OrdinaryKriging::distance(
733                        &self.points.row(i).to_vec(),
734                        &self.points.row(j).to_vec(),
735                    )
736                };
737                cov_matrix[[i, j]] = self.variogram.sill() - self.variogram.evaluate(dist);
738            }
739        }
740
741        // Build covariance vector
742        let mut cov_vector = Array1::zeros(self.n_points);
743        for i in 0..self.n_points {
744            let dist = OrdinaryKriging::distance(location, &self.points.row(i).to_vec());
745            cov_vector[i] = self.variogram.sill() - self.variogram.evaluate(dist);
746        }
747
748        // Solve for weights
749        let weights = SimpleKriging::solve_linear_system(&cov_matrix, &cov_vector)?;
750
751        // Calculate prediction (residuals from mean)
752        let residuals: Array1<f64> = &self.values - self.mean;
753        let value = self.mean + weights.dot(&residuals);
754
755        // Calculate prediction variance
756        let variance = self.variogram.sill() - weights.dot(&cov_vector);
757
758        Ok(KrigingPrediction {
759            value,
760            variance: variance.max(0.0),
761            weights,
762        })
763    }
764
765    /// Compute Euclidean distance between two points
766    #[allow(dead_code)]
767    fn distance(p1: &[f64], p2: &[f64]) -> f64 {
768        p1.iter()
769            .zip(p2.iter())
770            .map(|(a, b)| (a - b).powi(2))
771            .sum::<f64>()
772            .sqrt()
773    }
774
775    /// Solve linear system using Gaussian elimination
776    fn solve_linear_system(a: &Array2<f64>, b: &Array1<f64>) -> SpatialResult<Array1<f64>> {
777        let n = a.nrows();
778
779        // Create augmented matrix
780        let mut aug = Array2::zeros((n, n + 1));
781        for i in 0..n {
782            for j in 0..n {
783                aug[[i, j]] = a[[i, j]];
784            }
785            aug[[i, n]] = b[i];
786        }
787
788        // Forward elimination
789        for i in 0..n {
790            // Find pivot
791            let mut max_row = i;
792            for k in (i + 1)..n {
793                if aug[[k, i]].abs() > aug[[max_row, i]].abs() {
794                    max_row = k;
795                }
796            }
797
798            // Swap rows
799            if max_row != i {
800                for j in 0..(n + 1) {
801                    let temp = aug[[i, j]];
802                    aug[[i, j]] = aug[[max_row, j]];
803                    aug[[max_row, j]] = temp;
804                }
805            }
806
807            // Check for singular matrix
808            if aug[[i, i]].abs() < 1e-12 {
809                return Err(SpatialError::ComputationError(
810                    "Singular matrix in Kriging system".to_string(),
811                ));
812            }
813
814            // Eliminate
815            for k in (i + 1)..n {
816                let factor = aug[[k, i]] / aug[[i, i]];
817                for j in i..(n + 1) {
818                    aug[[k, j]] -= factor * aug[[i, j]];
819                }
820            }
821        }
822
823        // Back substitution
824        let mut solution = Array1::zeros(n);
825        for i in (0..n).rev() {
826            solution[i] = aug[[i, n]];
827            for j in (i + 1)..n {
828                solution[i] -= aug[[i, j]] * solution[j];
829            }
830            solution[i] /= aug[[i, i]];
831        }
832
833        Ok(solution)
834    }
835}
836
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr1;

    /// Corners of the unit square in the order (0,0), (1,0), (0,1), (1,1).
    fn unit_square_points() -> Array2<f64> {
        Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap()
    }

    /// The first three corners of the unit square: (0,0), (1,0), (0,1).
    fn corner_triangle_points() -> Array2<f64> {
        Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0]).unwrap()
    }

    /// Checks the unbiasedness constraint: Kriging weights sum to one.
    fn assert_weights_sum_to_one(prediction: &KrigingPrediction) {
        let total: f64 = prediction.weights.sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_variogram_models() {
        let spherical = VariogramModel::spherical(1.0, 0.5, 0.1);
        // Nugget at the origin; nugget + sill at and beyond the range.
        for &(h, expected) in &[(0.0, 0.1), (1.0, 0.6), (2.0, 0.6)] {
            assert_relative_eq!(spherical.evaluate(h), expected, epsilon = 1e-10);
        }

        let exponential = VariogramModel::exponential(1.0, 0.5, 0.1);
        assert_relative_eq!(exponential.evaluate(0.0), 0.1, epsilon = 1e-10);
        let at_unit = exponential.evaluate(1.0);
        assert!(at_unit > 0.1);
        assert!(at_unit < 0.6);

        let gaussian = VariogramModel::gaussian(1.0, 0.5, 0.1);
        assert_relative_eq!(gaussian.evaluate(0.0), 0.1, epsilon = 1e-10);
        assert!(gaussian.evaluate(1.0) > 0.1);

        let linear = VariogramModel::linear(0.2, 0.05);
        for &(h, expected) in &[(0.0, 0.05), (1.0, 0.25)] {
            assert_relative_eq!(linear.evaluate(h), expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_ordinary_kriging_basic() {
        let points = unit_square_points();
        let values = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let model = VariogramModel::spherical(1.5, 1.0, 0.1);

        let mut kriging = OrdinaryKriging::new(&points.view(), &values.view(), model).unwrap();
        kriging.fit().unwrap();

        let center = kriging.predict(&[0.5, 0.5]).unwrap();

        // The centre estimate lies strictly inside the data range.
        assert!(center.value > 1.0);
        assert!(center.value < 4.0);
        assert!(center.variance >= 0.0);
        assert_weights_sum_to_one(&center);
    }

    #[test]
    fn test_ordinary_kriging_exact_interpolation() {
        let points = corner_triangle_points();
        let values = arr1(&[1.0, 2.0, 3.0]);
        // A small nugget keeps the model close to an exact interpolator.
        let model = VariogramModel::spherical(1.0, 0.5, 0.01);
        let kriging = OrdinaryKriging::new(&points.view(), &values.view(), model).unwrap();

        let at_origin = kriging.predict(&[0.0, 0.0]).unwrap();
        assert_relative_eq!(at_origin.value, 1.0, epsilon = 1e-6);
        assert!(at_origin.variance < 0.1);
    }

    #[test]
    fn test_simple_kriging() {
        let points = corner_triangle_points();
        let values = arr1(&[1.5, 2.5, 3.5]);
        let known_mean = 2.0;
        let model = VariogramModel::exponential(1.0, 0.8, 0.1);

        let kriging =
            SimpleKriging::new(&points.view(), &values.view(), known_mean, model).unwrap();
        let estimate = kriging.predict(&[0.5, 0.5]).unwrap();

        assert!(estimate.value > 1.0);
        assert!(estimate.value < 4.0);
        assert!(estimate.variance >= 0.0);
    }

    #[test]
    fn test_batch_prediction() {
        let points = unit_square_points();
        let values = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let model = VariogramModel::spherical(1.5, 1.0, 0.1);

        let mut kriging = OrdinaryKriging::new(&points.view(), &values.view(), model).unwrap();
        kriging.fit().unwrap();

        let diagonal =
            Array2::from_shape_vec((3, 2), vec![0.25, 0.25, 0.5, 0.5, 0.75, 0.75]).unwrap();
        let predictions = kriging.predict_batch(&diagonal.view()).unwrap();

        assert_eq!(predictions.len(), 3);
        for estimate in predictions.iter() {
            assert!(estimate.value > 0.0);
            assert!(estimate.variance >= 0.0);
            assert_weights_sum_to_one(estimate);
        }
    }

    #[test]
    fn test_cross_validation() {
        let coords = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5];
        let points = Array2::from_shape_vec((5, 2), coords).unwrap();
        let values = arr1(&[1.0, 2.0, 3.0, 4.0, 2.5]);
        let model = VariogramModel::spherical(1.5, 1.0, 0.1);

        let kriging = OrdinaryKriging::new(&points.view(), &values.view(), model).unwrap();
        let errors = kriging.cross_validate().unwrap();

        assert_eq!(errors.len(), 5);
        // Loose sanity bound for this small configuration.
        assert!(errors.iter().all(|residual| residual.abs() < 5.0));
    }

    #[test]
    fn test_variogram_properties() {
        let bounded = VariogramModel::spherical(2.0, 1.0, 0.2);
        assert_relative_eq!(bounded.effective_range(), 2.0, epsilon = 1e-10);
        assert_relative_eq!(bounded.sill(), 1.2, epsilon = 1e-10);
        assert_relative_eq!(bounded.nugget(), 0.2, epsilon = 1e-10);

        let unbounded = VariogramModel::linear(0.5, 0.1);
        assert_eq!(unbounded.effective_range(), f64::INFINITY);
        assert_eq!(unbounded.sill(), f64::INFINITY);
        assert_relative_eq!(unbounded.nugget(), 0.1, epsilon = 1e-10);
    }

    #[test]
    fn test_error_cases() {
        let two_points = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 0.0]).unwrap();

        // Value count does not match point count.
        let three_values = arr1(&[1.0, 2.0, 3.0]);
        let mismatched = OrdinaryKriging::new(
            &two_points.view(),
            &three_values.view(),
            VariogramModel::spherical(1.0, 0.5, 0.1),
        );
        assert!(mismatched.is_err());

        // Too few points for ordinary kriging.
        let two_values = arr1(&[1.0, 2.0]);
        let too_few = OrdinaryKriging::new(
            &two_points.view(),
            &two_values.view(),
            VariogramModel::spherical(1.0, 0.5, 0.1),
        );
        assert!(too_few.is_err());
    }
}