Skip to main content

sklears_model_selection/
spatial_validation.rs

//! Spatial cross-validation for geographic and spatial data.
//!
//! This module provides cross-validation methods that account for spatial
//! autocorrelation and geographic dependencies in spatial datasets.
5
6use scirs2_core::RngExt;
7use sklears_core::error::{Result, SklearsError};
8use std::collections::HashSet;
9
/// A point in 2-D or 3-D space.
///
/// For planar data `x`/`y` are ordinary Cartesian coordinates. For geographic
/// data, [`haversine_distance`](Self::haversine_distance) interprets `x` as
/// longitude and `y` as latitude, both in degrees.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpatialCoordinate {
    /// First axis (longitude in degrees when used with Haversine distances).
    pub x: f64,
    /// Second axis (latitude in degrees when used with Haversine distances).
    pub y: f64,
    /// Optional third axis for 3-D spatial data; `None` for planar points.
    pub z: Option<f64>,
}

impl SpatialCoordinate {
    /// Creates a 2-D coordinate (`z` is `None`).
    pub fn new(x: f64, y: f64) -> Self {
        Self { x, y, z: None }
    }

    /// Creates a 3-D coordinate.
    pub fn new_3d(x: f64, y: f64, z: f64) -> Self {
        Self { x, y, z: Some(z) }
    }

    /// Calculate Euclidean distance between two coordinates.
    ///
    /// The `z` axis contributes only when *both* points carry a `z` value;
    /// otherwise the distance is computed in the `x`/`y` plane.
    pub fn distance(&self, other: &SpatialCoordinate) -> f64 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        // `zip` is `Some` only if both sides are `Some`.
        let dz = self.z.zip(other.z).map_or(0.0, |(a, b)| a - b);
        (dx * dx + dy * dy + dz * dz).sqrt()
    }

    /// Calculate Haversine (great-circle) distance in kilometers.
    ///
    /// Reads `y` as latitude and `x` as longitude, both in degrees, and uses
    /// a spherical Earth of radius 6371 km. `z` is ignored.
    pub fn haversine_distance(&self, other: &SpatialCoordinate) -> f64 {
        const EARTH_RADIUS_KM: f64 = 6371.0;

        let lat1 = self.y.to_radians();
        let lat2 = other.y.to_radians();
        let delta_lat = (other.y - self.y).to_radians();
        let delta_lon = (other.x - self.x).to_radians();

        // Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
        // points, which would make `(1 - a).sqrt()` NaN. Clamping keeps the
        // result finite; a NaN input still propagates as NaN.
        let a = ((delta_lat / 2.0).sin().powi(2)
            + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2))
        .clamp(0.0, 1.0);
        let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());

        EARTH_RADIUS_KM * c
    }
}
54
/// Configuration for spatial cross-validation
#[derive(Debug, Clone, PartialEq)]
pub struct SpatialValidationConfig {
    /// Number of folds for cross-validation (also the number of spatial
    /// clusters the samples are grouped into).
    pub n_splits: usize,
    /// Minimum distance between training and test samples.
    ///
    /// Expressed in the unit produced by `distance_method`: raw coordinate
    /// units for Euclidean/Manhattan/Chebyshev, kilometers for Haversine.
    pub buffer_distance: f64,
    /// Method for distance calculation
    pub distance_method: DistanceMethod,
    /// Clustering method for spatial grouping
    pub clustering_method: SpatialClusteringMethod,
    /// Random state for reproducible results
    pub random_state: Option<u64>,
    /// Whether to use geographic coordinates (lat/lon).
    ///
    /// NOTE(review): nothing in this module reads this flag; great-circle
    /// distances are selected through `distance_method` instead.
    pub geographic: bool,
}

impl Default for SpatialValidationConfig {
    fn default() -> Self {
        Self {
            n_splits: 5,
            // 1000 coordinate units with the default Euclidean metric. With
            // `DistanceMethod::Haversine` (kilometers) this is 1000 km, not
            // 1 km, so geographic users will usually want to override it.
            buffer_distance: 1000.0,
            distance_method: DistanceMethod::Euclidean,
            clustering_method: SpatialClusteringMethod::KMeans,
            random_state: None,
            geographic: false,
        }
    }
}

/// Distance calculation methods
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMethod {
    /// Straight-line distance in coordinate units.
    Euclidean,
    /// Great-circle distance in kilometers; treats `x` as longitude and `y`
    /// as latitude, in degrees.
    Haversine,
    /// Sum of the absolute per-axis differences.
    Manhattan,
    /// Largest absolute per-axis difference.
    Chebyshev,
}

/// Spatial clustering methods for grouping
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpatialClusteringMethod {
    /// Iterative centroid-based clustering with `n_splits` centroids.
    KMeans,
    /// Regular grid over the bounding box; cells are folded onto `n_splits` ids.
    Grid,
    /// Agglomerative merging of the closest groups until `n_splits` remain.
    Hierarchical,
    /// Density-based clustering; noise points are spread across the folds.
    DBSCAN,
}
110
/// Spatial cross-validator that accounts for spatial autocorrelation.
///
/// Samples are grouped into spatial clusters, each cluster is held out in
/// turn as the test fold, and training samples lying closer than the
/// configured buffer distance to any test sample are dropped.
#[derive(Debug, Clone)]
pub struct SpatialCrossValidator {
    /// Settings controlling clustering, distance computation and buffering.
    config: SpatialValidationConfig,
}
116
117impl SpatialCrossValidator {
118    pub fn new(config: SpatialValidationConfig) -> Self {
119        Self { config }
120    }
121
122    /// Generate spatial cross-validation splits
123    pub fn split(
124        &self,
125        n_samples: usize,
126        coordinates: &[SpatialCoordinate],
127    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
128        if coordinates.len() != n_samples {
129            return Err(SklearsError::InvalidInput(
130                "Number of coordinates must match number of samples".to_string(),
131            ));
132        }
133
134        // Create spatial clusters
135        let clusters = self.create_spatial_clusters(coordinates)?;
136
137        // Generate cross-validation splits based on clusters
138        let splits = self.generate_cluster_splits(&clusters)?;
139
140        // Apply buffer constraints
141        let filtered_splits = self.apply_buffer_constraints(&splits, coordinates)?;
142
143        Ok(filtered_splits)
144    }
145
146    /// Create spatial clusters using the specified method
147    fn create_spatial_clusters(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
148        match self.config.clustering_method {
149            SpatialClusteringMethod::KMeans => self.kmeans_clustering(coordinates),
150            SpatialClusteringMethod::Grid => self.grid_clustering(coordinates),
151            SpatialClusteringMethod::Hierarchical => self.hierarchical_clustering(coordinates),
152            SpatialClusteringMethod::DBSCAN => self.dbscan_clustering(coordinates),
153        }
154    }
155
156    /// K-means clustering for spatial grouping
157    fn kmeans_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
158        let n_samples = coordinates.len();
159        let mut clusters = vec![0; n_samples];
160        let mut centroids = Vec::new();
161
162        // Initialize centroids randomly
163        let mut rng = self.get_rng();
164        for _i in 0..self.config.n_splits {
165            let idx = rng.random_range(0..n_samples);
166            centroids.push(coordinates[idx]);
167        }
168
169        // K-means iterations
170        for _ in 0..100 {
171            // Max iterations
172            let mut new_centroids = vec![SpatialCoordinate::new(0.0, 0.0); self.config.n_splits];
173            let mut cluster_counts = vec![0; self.config.n_splits];
174            let mut changed = false;
175
176            // Assign points to nearest centroid
177            for (i, coord) in coordinates.iter().enumerate() {
178                let mut min_distance = f64::INFINITY;
179                let mut best_cluster = 0;
180
181                for (j, centroid) in centroids.iter().enumerate() {
182                    let distance = self.calculate_distance(coord, centroid);
183                    if distance < min_distance {
184                        min_distance = distance;
185                        best_cluster = j;
186                    }
187                }
188
189                if clusters[i] != best_cluster {
190                    changed = true;
191                    clusters[i] = best_cluster;
192                }
193
194                // Update centroid sum
195                new_centroids[best_cluster].x += coord.x;
196                new_centroids[best_cluster].y += coord.y;
197                if let Some(z) = coord.z {
198                    if new_centroids[best_cluster].z.is_none() {
199                        new_centroids[best_cluster].z = Some(0.0);
200                    }
201                    new_centroids[best_cluster].z = Some(
202                        new_centroids[best_cluster]
203                            .z
204                            .expect("operation should succeed")
205                            + z,
206                    );
207                }
208                cluster_counts[best_cluster] += 1;
209            }
210
211            // Update centroids
212            for (i, count) in cluster_counts.iter().enumerate() {
213                if *count > 0 {
214                    new_centroids[i].x /= *count as f64;
215                    new_centroids[i].y /= *count as f64;
216                    if let Some(z) = new_centroids[i].z {
217                        new_centroids[i].z = Some(z / *count as f64);
218                    }
219                }
220            }
221
222            centroids = new_centroids;
223
224            if !changed {
225                break;
226            }
227        }
228
229        Ok(clusters)
230    }
231
232    /// Grid-based clustering for regular spatial division
233    fn grid_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
234        // Find bounds
235        let min_x = coordinates
236            .iter()
237            .map(|c| c.x)
238            .fold(f64::INFINITY, f64::min);
239        let max_x = coordinates
240            .iter()
241            .map(|c| c.x)
242            .fold(f64::NEG_INFINITY, f64::max);
243        let min_y = coordinates
244            .iter()
245            .map(|c| c.y)
246            .fold(f64::INFINITY, f64::min);
247        let max_y = coordinates
248            .iter()
249            .map(|c| c.y)
250            .fold(f64::NEG_INFINITY, f64::max);
251
252        // Calculate grid dimensions
253        let grid_size = (self.config.n_splits as f64).sqrt().ceil() as usize;
254        let x_step = (max_x - min_x) / grid_size as f64;
255        let y_step = (max_y - min_y) / grid_size as f64;
256
257        let mut clusters = Vec::new();
258
259        for coord in coordinates {
260            let x_grid = ((coord.x - min_x) / x_step).floor() as usize;
261            let y_grid = ((coord.y - min_y) / y_step).floor() as usize;
262
263            let x_grid = x_grid.min(grid_size - 1);
264            let y_grid = y_grid.min(grid_size - 1);
265
266            let cluster_id = (y_grid * grid_size + x_grid) % self.config.n_splits;
267            clusters.push(cluster_id);
268        }
269
270        Ok(clusters)
271    }
272
273    /// Hierarchical clustering for spatial grouping
274    fn hierarchical_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
275        let n_samples = coordinates.len();
276
277        // Calculate distance matrix
278        let mut distances = vec![vec![0.0; n_samples]; n_samples];
279        for i in 0..n_samples {
280            for j in i + 1..n_samples {
281                let dist = self.calculate_distance(&coordinates[i], &coordinates[j]);
282                distances[i][j] = dist;
283                distances[j][i] = dist;
284            }
285        }
286
287        // Simple hierarchical clustering
288        let mut clusters = (0..n_samples).collect::<Vec<_>>();
289        let mut cluster_map = (0..n_samples).collect::<Vec<_>>();
290
291        // Merge closest clusters until we have n_splits clusters
292        while clusters.len() > self.config.n_splits {
293            let mut min_distance = f64::INFINITY;
294            let mut merge_i = 0;
295            let mut merge_j = 0;
296
297            for i in 0..clusters.len() {
298                for j in i + 1..clusters.len() {
299                    let dist = distances[clusters[i]][clusters[j]];
300                    if dist < min_distance {
301                        min_distance = dist;
302                        merge_i = i;
303                        merge_j = j;
304                    }
305                }
306            }
307
308            // Merge clusters
309            let cluster_j = clusters.remove(merge_j);
310            let cluster_i = clusters[merge_i];
311
312            // Update cluster assignments
313            for assignment in &mut cluster_map {
314                if *assignment == cluster_j {
315                    *assignment = cluster_i;
316                }
317            }
318        }
319
320        // Map cluster IDs to 0..n_splits
321        let unique_clusters: Vec<_> = cluster_map
322            .iter()
323            .cloned()
324            .collect::<HashSet<_>>()
325            .into_iter()
326            .collect();
327        let mut final_clusters = vec![0; n_samples];
328
329        for (i, &cluster_id) in cluster_map.iter().enumerate() {
330            final_clusters[i] = unique_clusters
331                .iter()
332                .position(|&x| x == cluster_id)
333                .unwrap_or(0);
334        }
335
336        Ok(final_clusters)
337    }
338
339    /// DBSCAN clustering for density-based spatial grouping
340    fn dbscan_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
341        let n_samples = coordinates.len();
342        let eps = self.config.buffer_distance / 2.0;
343        let min_pts = (n_samples / self.config.n_splits).max(2);
344
345        let mut clusters = vec![None; n_samples];
346        let mut visited = vec![false; n_samples];
347        let mut cluster_id = 0;
348
349        for i in 0..n_samples {
350            if visited[i] {
351                continue;
352            }
353
354            visited[i] = true;
355            let neighbors = self.find_neighbors(i, coordinates, eps);
356
357            if neighbors.len() < min_pts {
358                clusters[i] = Some(usize::MAX); // Noise point
359            } else {
360                self.expand_cluster(
361                    i,
362                    &neighbors,
363                    cluster_id,
364                    coordinates,
365                    eps,
366                    min_pts,
367                    &mut clusters,
368                    &mut visited,
369                );
370                cluster_id += 1;
371            }
372        }
373
374        // Convert to 0..n_splits range
375        let max_clusters = cluster_id.min(self.config.n_splits);
376        let mut final_clusters = vec![0; n_samples];
377
378        for (i, cluster) in clusters.iter().enumerate() {
379            final_clusters[i] = match cluster {
380                Some(id) if *id != usize::MAX => *id % max_clusters,
381                _ => i % self.config.n_splits, // Assign noise points randomly
382            };
383        }
384
385        Ok(final_clusters)
386    }
387
388    fn find_neighbors(
389        &self,
390        point: usize,
391        coordinates: &[SpatialCoordinate],
392        eps: f64,
393    ) -> Vec<usize> {
394        let mut neighbors = Vec::new();
395        for (i, coord) in coordinates.iter().enumerate() {
396            if i != point && self.calculate_distance(&coordinates[point], coord) <= eps {
397                neighbors.push(i);
398            }
399        }
400        neighbors
401    }
402
403    #[allow(clippy::too_many_arguments)]
404    fn expand_cluster(
405        &self,
406        point: usize,
407        neighbors: &[usize],
408        cluster_id: usize,
409        coordinates: &[SpatialCoordinate],
410        eps: f64,
411        min_pts: usize,
412        clusters: &mut [Option<usize>],
413        visited: &mut [bool],
414    ) {
415        clusters[point] = Some(cluster_id);
416        let mut seed_set = neighbors.to_vec();
417        let mut i = 0;
418
419        while i < seed_set.len() {
420            let q = seed_set[i];
421
422            if !visited[q] {
423                visited[q] = true;
424                let q_neighbors = self.find_neighbors(q, coordinates, eps);
425
426                if q_neighbors.len() >= min_pts {
427                    seed_set.extend(q_neighbors);
428                }
429            }
430
431            if clusters[q].is_none() {
432                clusters[q] = Some(cluster_id);
433            }
434
435            i += 1;
436        }
437    }
438
439    /// Generate cross-validation splits based on spatial clusters
440    fn generate_cluster_splits(&self, clusters: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
441        let mut splits = Vec::new();
442
443        for test_cluster in 0..self.config.n_splits {
444            let mut train_indices = Vec::new();
445            let mut test_indices = Vec::new();
446
447            for (i, &cluster) in clusters.iter().enumerate() {
448                if cluster == test_cluster {
449                    test_indices.push(i);
450                } else {
451                    train_indices.push(i);
452                }
453            }
454
455            if !train_indices.is_empty() && !test_indices.is_empty() {
456                splits.push((train_indices, test_indices));
457            }
458        }
459
460        Ok(splits)
461    }
462
463    /// Apply buffer constraints to prevent spatial leakage
464    fn apply_buffer_constraints(
465        &self,
466        splits: &[(Vec<usize>, Vec<usize>)],
467        coordinates: &[SpatialCoordinate],
468    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
469        let mut filtered_splits = Vec::new();
470
471        for (train_indices, test_indices) in splits {
472            let mut filtered_train = Vec::new();
473
474            for &train_idx in train_indices {
475                let mut too_close = false;
476
477                for &test_idx in test_indices {
478                    let distance =
479                        self.calculate_distance(&coordinates[train_idx], &coordinates[test_idx]);
480
481                    if distance < self.config.buffer_distance {
482                        too_close = true;
483                        break;
484                    }
485                }
486
487                if !too_close {
488                    filtered_train.push(train_idx);
489                }
490            }
491
492            if !filtered_train.is_empty() && !test_indices.is_empty() {
493                filtered_splits.push((filtered_train, test_indices.clone()));
494            }
495        }
496
497        Ok(filtered_splits)
498    }
499
500    /// Calculate distance between two coordinates
501    fn calculate_distance(&self, coord1: &SpatialCoordinate, coord2: &SpatialCoordinate) -> f64 {
502        match self.config.distance_method {
503            DistanceMethod::Euclidean => coord1.distance(coord2),
504            DistanceMethod::Haversine => coord1.haversine_distance(coord2),
505            DistanceMethod::Manhattan => (coord1.x - coord2.x).abs() + (coord1.y - coord2.y).abs(),
506            DistanceMethod::Chebyshev => {
507                (coord1.x - coord2.x).abs().max((coord1.y - coord2.y).abs())
508            }
509        }
510    }
511
512    fn get_rng(&self) -> impl scirs2_core::random::Rng {
513        use scirs2_core::random::rngs::StdRng;
514        use scirs2_core::random::SeedableRng;
515        match self.config.random_state {
516            Some(seed) => StdRng::seed_from_u64(seed),
517            None => StdRng::seed_from_u64(42),
518        }
519    }
520}
521
522/// Leave-one-region-out cross-validator for spatial data
523#[derive(Debug, Clone)]
524pub struct LeaveOneRegionOut {
525    region_labels: Vec<usize>,
526}
527
528impl LeaveOneRegionOut {
529    pub fn new(region_labels: Vec<usize>) -> Self {
530        Self { region_labels }
531    }
532
533    /// Generate splits where each region is used as test set once
534    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
535        if self.region_labels.len() != n_samples {
536            return Err(SklearsError::InvalidInput(
537                "Region labels length must match number of samples".to_string(),
538            ));
539        }
540
541        let unique_regions: HashSet<usize> = self.region_labels.iter().cloned().collect();
542        let mut splits = Vec::new();
543
544        for test_region in unique_regions {
545            let mut train_indices = Vec::new();
546            let mut test_indices = Vec::new();
547
548            for (i, &region) in self.region_labels.iter().enumerate() {
549                if region == test_region {
550                    test_indices.push(i);
551                } else {
552                    train_indices.push(i);
553                }
554            }
555
556            if !train_indices.is_empty() && !test_indices.is_empty() {
557                splits.push((train_indices, test_indices));
558            }
559        }
560
561        Ok(splits)
562    }
563}
564
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every fold has at least one training and one test sample.
    fn assert_folds_populated(splits: &[(Vec<usize>, Vec<usize>)]) {
        for (train_indices, test_indices) in splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }

    #[test]
    fn test_spatial_coordinate_distance() {
        // Classic 3-4-5 right triangle.
        let origin = SpatialCoordinate::new(0.0, 0.0);
        let corner = SpatialCoordinate::new(3.0, 4.0);

        let error = (origin.distance(&corner) - 5.0).abs();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_spatial_cross_validator() {
        let cv = SpatialCrossValidator::new(SpatialValidationConfig {
            buffer_distance: 1.0, // Use a reasonable buffer distance for test data
            ..Default::default()
        });

        // 5 x 5 unit grid, row-major.
        let coordinates: Vec<SpatialCoordinate> = (0..25)
            .map(|i| SpatialCoordinate::new((i % 5) as f64, (i / 5) as f64))
            .collect();

        let splits = cv
            .split(25, &coordinates)
            .expect("operation should succeed");
        assert!(!splits.is_empty(), "Should generate at least one split");

        assert_folds_populated(&splits);
    }

    #[test]
    fn test_leave_one_region_out() {
        let cv = LeaveOneRegionOut::new(vec![0, 0, 1, 1, 2, 2]);

        let splits = cv.split(6).expect("operation should succeed");
        assert_eq!(splits.len(), 3, "Should have 3 splits for 3 regions");

        assert_folds_populated(&splits);
    }
}