sklears_model_selection/
spatial_validation.rs

1//! Spatial cross-validation for geographic and spatial data
2//!
3//! This module provides cross-validation methods that account for spatial
4//! autocorrelation and geographic dependencies in spatial datasets.
5
6use scirs2_core::random::Rng;
7use sklears_core::error::{Result, SklearsError};
8use std::collections::HashSet;
9
/// Spatial coordinates for geographic data
///
/// A point in 2D (`z == None`) or 3D (`z == Some(_)`) space. When used with
/// [`SpatialCoordinate::haversine_distance`], `x` is read as longitude and `y`
/// as latitude, both in degrees.
#[derive(Debug, Clone, Copy)]
pub struct SpatialCoordinate {
    /// First planar coordinate (longitude in degrees for Haversine distances).
    pub x: f64,
    /// Second planar coordinate (latitude in degrees for Haversine distances).
    pub y: f64,
    /// Optional third coordinate for 3D spatial data.
    pub z: Option<f64>,
}

impl SpatialCoordinate {
    /// Creates a 2D coordinate; `z` is left as `None`.
    pub fn new(x: f64, y: f64) -> Self {
        Self { x, y, z: None }
    }

    /// Creates a 3D coordinate.
    pub fn new_3d(x: f64, y: f64, z: f64) -> Self {
        Self { x, y, z: Some(z) }
    }

    /// Calculate Euclidean distance between two coordinates
    ///
    /// The `z` difference only contributes when *both* points carry a `z`
    /// value; otherwise the distance is computed in the `x`/`y` plane.
    pub fn distance(&self, other: &SpatialCoordinate) -> f64 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        let dz = match (self.z, other.z) {
            (Some(z1), Some(z2)) => z1 - z2,
            _ => 0.0,
        };
        (dx * dx + dy * dy + dz * dz).sqrt()
    }

    /// Calculate Haversine distance for lat/lon coordinates (in kilometers)
    ///
    /// Treats `x` as longitude and `y` as latitude in degrees and uses a
    /// spherical Earth of radius 6371 km. `z` is ignored.
    pub fn haversine_distance(&self, other: &SpatialCoordinate) -> f64 {
        const EARTH_RADIUS_KM: f64 = 6371.0;

        let lat1 = self.y.to_radians();
        let lat2 = other.y.to_radians();
        let delta_lat = (other.y - self.y).to_radians();
        let delta_lon = (other.x - self.x).to_radians();

        let a = (delta_lat / 2.0).sin().powi(2)
            + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
        // Mathematically `a` lies in [0, 1], but rounding can push it slightly
        // outside for (near-)antipodal points, which would make
        // `(1.0 - a).sqrt()` NaN. `clamp` keeps a NaN input as NaN.
        let a = a.clamp(0.0, 1.0);
        let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());

        EARTH_RADIUS_KM * c
    }
}
54
/// Configuration for spatial cross-validation
///
/// Read by `SpatialCrossValidator` to decide how samples are grouped into
/// spatial folds and how training samples near a test fold are filtered out.
#[derive(Debug, Clone)]
pub struct SpatialValidationConfig {
    /// Number of folds for cross-validation
    ///
    /// Also the number of spatial groups the clustering step aims for; each
    /// group label in `0..n_splits` is held out once as the test fold.
    pub n_splits: usize,
    /// Minimum distance between training and test samples
    ///
    /// Training samples closer than this to any test sample are dropped from
    /// the fold. Expressed in the units of `distance_method` (kilometres for
    /// Haversine). The DBSCAN clustering option also uses half of this value
    /// as its neighbourhood radius.
    pub buffer_distance: f64,
    /// Method for distance calculation
    pub distance_method: DistanceMethod,
    /// Clustering method for spatial grouping
    pub clustering_method: SpatialClusteringMethod,
    /// Random state for reproducible results
    ///
    /// Seed for the random number generator; when `None`, a fixed seed of 42
    /// is used, so results are repeatable in both cases.
    pub random_state: Option<u64>,
    /// Whether to use geographic coordinates (lat/lon)
    ///
    /// NOTE(review): no code in this file reads this flag; the metric is
    /// chosen solely by `distance_method`. Confirm whether it is used elsewhere.
    pub geographic: bool,
}

impl Default for SpatialValidationConfig {
    /// Five folds, k-means grouping, Euclidean distances and no explicit seed.
    fn default() -> Self {
        Self {
            n_splits: 5,
            // Interpreted in the units of `distance_method`. With the default
            // Euclidean metric this is 1000 coordinate units (1 km only if the
            // coordinates are in metres); with Haversine it would be 1000 km.
            buffer_distance: 1000.0,
            distance_method: DistanceMethod::Euclidean,
            clustering_method: SpatialClusteringMethod::KMeans,
            random_state: None,
            geographic: false,
        }
    }
}
84
/// Distance calculation methods
///
/// Selects the metric used for clustering, neighbour searches and the
/// train/test buffer filter.
#[derive(Debug, Clone)]
pub enum DistanceMethod {
    /// Straight-line distance, as computed by `SpatialCoordinate::distance`.
    Euclidean,
    /// Great-circle distance in kilometres for lat/lon coordinates, as
    /// computed by `SpatialCoordinate::haversine_distance`.
    Haversine,
    /// Sum of the absolute coordinate differences.
    Manhattan,
    /// Largest absolute coordinate difference.
    Chebyshev,
}
97
/// Spatial clustering methods for grouping
///
/// Each method assigns every sample a group label that is later used as its
/// cross-validation fold.
#[derive(Debug, Clone)]
pub enum SpatialClusteringMethod {
    /// Iterative nearest-centroid clustering with `n_splits` centroids that
    /// are initialised from randomly chosen samples.
    KMeans,
    /// Regular grid over the bounding box of the coordinates, with
    /// `ceil(sqrt(n_splits))` cells per axis; cell indices are folded into
    /// `0..n_splits` with a modulo.
    Grid,
    /// Agglomerative clustering: the closest pair of groups is merged
    /// repeatedly until at most `n_splits` groups remain.
    Hierarchical,
    /// Density-based clustering with a neighbourhood radius of
    /// `buffer_distance / 2` and a minimum neighbour count of
    /// `max(n_samples / n_splits, 2)`; cluster ids are folded into the fold
    /// range and noise points are spread over the folds by sample index.
    DBSCAN,
}
110
/// Spatial cross-validator that accounts for spatial autocorrelation
///
/// Groups samples into spatial clusters, holds out one cluster at a time as
/// the test fold, and removes training samples that lie within the configured
/// buffer distance of any test sample.
#[derive(Debug, Clone)]
pub struct SpatialCrossValidator {
    /// Settings controlling fold count, distance metric, clustering method,
    /// buffer distance and seeding.
    config: SpatialValidationConfig,
}
116
117impl SpatialCrossValidator {
    /// Creates a validator that uses `config` for every call to `split`.
    pub fn new(config: SpatialValidationConfig) -> Self {
        Self { config }
    }
121
122    /// Generate spatial cross-validation splits
123    pub fn split(
124        &self,
125        n_samples: usize,
126        coordinates: &[SpatialCoordinate],
127    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
128        if coordinates.len() != n_samples {
129            return Err(SklearsError::InvalidInput(
130                "Number of coordinates must match number of samples".to_string(),
131            ));
132        }
133
134        // Create spatial clusters
135        let clusters = self.create_spatial_clusters(coordinates)?;
136
137        // Generate cross-validation splits based on clusters
138        let splits = self.generate_cluster_splits(&clusters)?;
139
140        // Apply buffer constraints
141        let filtered_splits = self.apply_buffer_constraints(&splits, coordinates)?;
142
143        Ok(filtered_splits)
144    }
145
146    /// Create spatial clusters using the specified method
147    fn create_spatial_clusters(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
148        match self.config.clustering_method {
149            SpatialClusteringMethod::KMeans => self.kmeans_clustering(coordinates),
150            SpatialClusteringMethod::Grid => self.grid_clustering(coordinates),
151            SpatialClusteringMethod::Hierarchical => self.hierarchical_clustering(coordinates),
152            SpatialClusteringMethod::DBSCAN => self.dbscan_clustering(coordinates),
153        }
154    }
155
156    /// K-means clustering for spatial grouping
157    fn kmeans_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
158        let n_samples = coordinates.len();
159        let mut clusters = vec![0; n_samples];
160        let mut centroids = Vec::new();
161
162        // Initialize centroids randomly
163        let mut rng = self.get_rng();
164        for _i in 0..self.config.n_splits {
165            let idx = rng.gen_range(0..n_samples);
166            centroids.push(coordinates[idx]);
167        }
168
169        // K-means iterations
170        for _ in 0..100 {
171            // Max iterations
172            let mut new_centroids = vec![SpatialCoordinate::new(0.0, 0.0); self.config.n_splits];
173            let mut cluster_counts = vec![0; self.config.n_splits];
174            let mut changed = false;
175
176            // Assign points to nearest centroid
177            for (i, coord) in coordinates.iter().enumerate() {
178                let mut min_distance = f64::INFINITY;
179                let mut best_cluster = 0;
180
181                for (j, centroid) in centroids.iter().enumerate() {
182                    let distance = self.calculate_distance(coord, centroid);
183                    if distance < min_distance {
184                        min_distance = distance;
185                        best_cluster = j;
186                    }
187                }
188
189                if clusters[i] != best_cluster {
190                    changed = true;
191                    clusters[i] = best_cluster;
192                }
193
194                // Update centroid sum
195                new_centroids[best_cluster].x += coord.x;
196                new_centroids[best_cluster].y += coord.y;
197                if let Some(z) = coord.z {
198                    if new_centroids[best_cluster].z.is_none() {
199                        new_centroids[best_cluster].z = Some(0.0);
200                    }
201                    new_centroids[best_cluster].z =
202                        Some(new_centroids[best_cluster].z.unwrap() + z);
203                }
204                cluster_counts[best_cluster] += 1;
205            }
206
207            // Update centroids
208            for (i, count) in cluster_counts.iter().enumerate() {
209                if *count > 0 {
210                    new_centroids[i].x /= *count as f64;
211                    new_centroids[i].y /= *count as f64;
212                    if let Some(z) = new_centroids[i].z {
213                        new_centroids[i].z = Some(z / *count as f64);
214                    }
215                }
216            }
217
218            centroids = new_centroids;
219
220            if !changed {
221                break;
222            }
223        }
224
225        Ok(clusters)
226    }
227
228    /// Grid-based clustering for regular spatial division
229    fn grid_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
230        // Find bounds
231        let min_x = coordinates
232            .iter()
233            .map(|c| c.x)
234            .fold(f64::INFINITY, f64::min);
235        let max_x = coordinates
236            .iter()
237            .map(|c| c.x)
238            .fold(f64::NEG_INFINITY, f64::max);
239        let min_y = coordinates
240            .iter()
241            .map(|c| c.y)
242            .fold(f64::INFINITY, f64::min);
243        let max_y = coordinates
244            .iter()
245            .map(|c| c.y)
246            .fold(f64::NEG_INFINITY, f64::max);
247
248        // Calculate grid dimensions
249        let grid_size = (self.config.n_splits as f64).sqrt().ceil() as usize;
250        let x_step = (max_x - min_x) / grid_size as f64;
251        let y_step = (max_y - min_y) / grid_size as f64;
252
253        let mut clusters = Vec::new();
254
255        for coord in coordinates {
256            let x_grid = ((coord.x - min_x) / x_step).floor() as usize;
257            let y_grid = ((coord.y - min_y) / y_step).floor() as usize;
258
259            let x_grid = x_grid.min(grid_size - 1);
260            let y_grid = y_grid.min(grid_size - 1);
261
262            let cluster_id = (y_grid * grid_size + x_grid) % self.config.n_splits;
263            clusters.push(cluster_id);
264        }
265
266        Ok(clusters)
267    }
268
269    /// Hierarchical clustering for spatial grouping
270    fn hierarchical_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
271        let n_samples = coordinates.len();
272
273        // Calculate distance matrix
274        let mut distances = vec![vec![0.0; n_samples]; n_samples];
275        for i in 0..n_samples {
276            for j in i + 1..n_samples {
277                let dist = self.calculate_distance(&coordinates[i], &coordinates[j]);
278                distances[i][j] = dist;
279                distances[j][i] = dist;
280            }
281        }
282
283        // Simple hierarchical clustering
284        let mut clusters = (0..n_samples).collect::<Vec<_>>();
285        let mut cluster_map = (0..n_samples).collect::<Vec<_>>();
286
287        // Merge closest clusters until we have n_splits clusters
288        while clusters.len() > self.config.n_splits {
289            let mut min_distance = f64::INFINITY;
290            let mut merge_i = 0;
291            let mut merge_j = 0;
292
293            for i in 0..clusters.len() {
294                for j in i + 1..clusters.len() {
295                    let dist = distances[clusters[i]][clusters[j]];
296                    if dist < min_distance {
297                        min_distance = dist;
298                        merge_i = i;
299                        merge_j = j;
300                    }
301                }
302            }
303
304            // Merge clusters
305            let cluster_j = clusters.remove(merge_j);
306            let cluster_i = clusters[merge_i];
307
308            // Update cluster assignments
309            for assignment in &mut cluster_map {
310                if *assignment == cluster_j {
311                    *assignment = cluster_i;
312                }
313            }
314        }
315
316        // Map cluster IDs to 0..n_splits
317        let unique_clusters: Vec<_> = cluster_map
318            .iter()
319            .cloned()
320            .collect::<HashSet<_>>()
321            .into_iter()
322            .collect();
323        let mut final_clusters = vec![0; n_samples];
324
325        for (i, &cluster_id) in cluster_map.iter().enumerate() {
326            final_clusters[i] = unique_clusters
327                .iter()
328                .position(|&x| x == cluster_id)
329                .unwrap_or(0);
330        }
331
332        Ok(final_clusters)
333    }
334
335    /// DBSCAN clustering for density-based spatial grouping
336    fn dbscan_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
337        let n_samples = coordinates.len();
338        let eps = self.config.buffer_distance / 2.0;
339        let min_pts = (n_samples / self.config.n_splits).max(2);
340
341        let mut clusters = vec![None; n_samples];
342        let mut visited = vec![false; n_samples];
343        let mut cluster_id = 0;
344
345        for i in 0..n_samples {
346            if visited[i] {
347                continue;
348            }
349
350            visited[i] = true;
351            let neighbors = self.find_neighbors(i, coordinates, eps);
352
353            if neighbors.len() < min_pts {
354                clusters[i] = Some(usize::MAX); // Noise point
355            } else {
356                self.expand_cluster(
357                    i,
358                    &neighbors,
359                    cluster_id,
360                    coordinates,
361                    eps,
362                    min_pts,
363                    &mut clusters,
364                    &mut visited,
365                );
366                cluster_id += 1;
367            }
368        }
369
370        // Convert to 0..n_splits range
371        let max_clusters = cluster_id.min(self.config.n_splits);
372        let mut final_clusters = vec![0; n_samples];
373
374        for (i, cluster) in clusters.iter().enumerate() {
375            final_clusters[i] = match cluster {
376                Some(id) if *id != usize::MAX => *id % max_clusters,
377                _ => i % self.config.n_splits, // Assign noise points randomly
378            };
379        }
380
381        Ok(final_clusters)
382    }
383
384    fn find_neighbors(
385        &self,
386        point: usize,
387        coordinates: &[SpatialCoordinate],
388        eps: f64,
389    ) -> Vec<usize> {
390        let mut neighbors = Vec::new();
391        for (i, coord) in coordinates.iter().enumerate() {
392            if i != point && self.calculate_distance(&coordinates[point], coord) <= eps {
393                neighbors.push(i);
394            }
395        }
396        neighbors
397    }
398
399    #[allow(clippy::too_many_arguments)]
400    fn expand_cluster(
401        &self,
402        point: usize,
403        neighbors: &[usize],
404        cluster_id: usize,
405        coordinates: &[SpatialCoordinate],
406        eps: f64,
407        min_pts: usize,
408        clusters: &mut [Option<usize>],
409        visited: &mut [bool],
410    ) {
411        clusters[point] = Some(cluster_id);
412        let mut seed_set = neighbors.to_vec();
413        let mut i = 0;
414
415        while i < seed_set.len() {
416            let q = seed_set[i];
417
418            if !visited[q] {
419                visited[q] = true;
420                let q_neighbors = self.find_neighbors(q, coordinates, eps);
421
422                if q_neighbors.len() >= min_pts {
423                    seed_set.extend(q_neighbors);
424                }
425            }
426
427            if clusters[q].is_none() {
428                clusters[q] = Some(cluster_id);
429            }
430
431            i += 1;
432        }
433    }
434
435    /// Generate cross-validation splits based on spatial clusters
436    fn generate_cluster_splits(&self, clusters: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
437        let mut splits = Vec::new();
438
439        for test_cluster in 0..self.config.n_splits {
440            let mut train_indices = Vec::new();
441            let mut test_indices = Vec::new();
442
443            for (i, &cluster) in clusters.iter().enumerate() {
444                if cluster == test_cluster {
445                    test_indices.push(i);
446                } else {
447                    train_indices.push(i);
448                }
449            }
450
451            if !train_indices.is_empty() && !test_indices.is_empty() {
452                splits.push((train_indices, test_indices));
453            }
454        }
455
456        Ok(splits)
457    }
458
459    /// Apply buffer constraints to prevent spatial leakage
460    fn apply_buffer_constraints(
461        &self,
462        splits: &[(Vec<usize>, Vec<usize>)],
463        coordinates: &[SpatialCoordinate],
464    ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
465        let mut filtered_splits = Vec::new();
466
467        for (train_indices, test_indices) in splits {
468            let mut filtered_train = Vec::new();
469
470            for &train_idx in train_indices {
471                let mut too_close = false;
472
473                for &test_idx in test_indices {
474                    let distance =
475                        self.calculate_distance(&coordinates[train_idx], &coordinates[test_idx]);
476
477                    if distance < self.config.buffer_distance {
478                        too_close = true;
479                        break;
480                    }
481                }
482
483                if !too_close {
484                    filtered_train.push(train_idx);
485                }
486            }
487
488            if !filtered_train.is_empty() && !test_indices.is_empty() {
489                filtered_splits.push((filtered_train, test_indices.clone()));
490            }
491        }
492
493        Ok(filtered_splits)
494    }
495
496    /// Calculate distance between two coordinates
497    fn calculate_distance(&self, coord1: &SpatialCoordinate, coord2: &SpatialCoordinate) -> f64 {
498        match self.config.distance_method {
499            DistanceMethod::Euclidean => coord1.distance(coord2),
500            DistanceMethod::Haversine => coord1.haversine_distance(coord2),
501            DistanceMethod::Manhattan => (coord1.x - coord2.x).abs() + (coord1.y - coord2.y).abs(),
502            DistanceMethod::Chebyshev => {
503                (coord1.x - coord2.x).abs().max((coord1.y - coord2.y).abs())
504            }
505        }
506    }
507
508    fn get_rng(&self) -> impl scirs2_core::random::Rng {
509        use scirs2_core::random::rngs::StdRng;
510        use scirs2_core::random::SeedableRng;
511        match self.config.random_state {
512            Some(seed) => StdRng::seed_from_u64(seed),
513            None => StdRng::seed_from_u64(42),
514        }
515    }
516}
517
518/// Leave-one-region-out cross-validator for spatial data
519#[derive(Debug, Clone)]
520pub struct LeaveOneRegionOut {
521    region_labels: Vec<usize>,
522}
523
524impl LeaveOneRegionOut {
525    pub fn new(region_labels: Vec<usize>) -> Self {
526        Self { region_labels }
527    }
528
529    /// Generate splits where each region is used as test set once
530    pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
531        if self.region_labels.len() != n_samples {
532            return Err(SklearsError::InvalidInput(
533                "Region labels length must match number of samples".to_string(),
534            ));
535        }
536
537        let unique_regions: HashSet<usize> = self.region_labels.iter().cloned().collect();
538        let mut splits = Vec::new();
539
540        for test_region in unique_regions {
541            let mut train_indices = Vec::new();
542            let mut test_indices = Vec::new();
543
544            for (i, &region) in self.region_labels.iter().enumerate() {
545                if region == test_region {
546                    test_indices.push(i);
547                } else {
548                    train_indices.push(i);
549                }
550            }
551
552            if !train_indices.is_empty() && !test_indices.is_empty() {
553                splits.push((train_indices, test_indices));
554            }
555        }
556
557        Ok(splits)
558    }
559}
560
// NOTE(review): no non-snake-case identifiers are visible in this module, so
// this `allow` looks unnecessary; confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A 3-4-5 right triangle pins the 2D Euclidean distance.
    #[test]
    fn test_spatial_coordinate_distance() {
        let coord1 = SpatialCoordinate::new(0.0, 0.0);
        let coord2 = SpatialCoordinate::new(3.0, 4.0);

        assert!((coord1.distance(&coord2) - 5.0).abs() < 1e-10);
    }

    /// Smoke test: the default configuration (apart from the buffer) yields at
    /// least one split on a 5x5 unit grid, and no split has an empty side.
    #[test]
    fn test_spatial_cross_validator() {
        let config = SpatialValidationConfig {
            buffer_distance: 1.0, // Use a reasonable buffer distance for test data
            ..Default::default()
        };
        let cv = SpatialCrossValidator::new(config);

        // Create simple grid of coordinates
        let mut coordinates = Vec::new();
        for i in 0..25 {
            let x = (i % 5) as f64;
            let y = (i / 5) as f64;
            coordinates.push(SpatialCoordinate::new(x, y));
        }

        let splits = cv.split(25, &coordinates).unwrap();
        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in &splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }

    /// Three distinct region labels must produce exactly three non-empty splits.
    #[test]
    fn test_leave_one_region_out() {
        let region_labels = vec![0, 0, 1, 1, 2, 2];
        let cv = LeaveOneRegionOut::new(region_labels);

        let splits = cv.split(6).unwrap();
        assert_eq!(splits.len(), 3, "Should have 3 splits for 3 regions");

        for (train_indices, test_indices) in &splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }
}
618}