1use scirs2_core::random::Rng;
7use sklears_core::error::{Result, SklearsError};
8use std::collections::HashSet;
9
/// A 2D or 3D point used to build spatial cross-validation folds.
///
/// For geographic data used with [`SpatialCoordinate::haversine_distance`],
/// `x` is longitude and `y` is latitude, both in degrees.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpatialCoordinate {
    /// First planar axis (longitude in degrees for geographic data).
    pub x: f64,
    /// Second planar axis (latitude in degrees for geographic data).
    pub y: f64,
    /// Optional third axis; `None` for 2D points.
    pub z: Option<f64>,
}

impl SpatialCoordinate {
    /// Creates a 2D coordinate (`z` is `None`).
    pub fn new(x: f64, y: f64) -> Self {
        Self { x, y, z: None }
    }

    /// Creates a 3D coordinate.
    pub fn new_3d(x: f64, y: f64, z: f64) -> Self {
        Self { x, y, z: Some(z) }
    }

    /// Straight-line (Euclidean) distance to `other`.
    ///
    /// The `z` axis contributes only when *both* points carry a `z` value;
    /// otherwise the distance is measured in the x/y plane.
    pub fn distance(&self, other: &SpatialCoordinate) -> f64 {
        let dx = self.x - other.x;
        let dy = self.y - other.y;
        let dz = match (self.z, other.z) {
            (Some(z1), Some(z2)) => z1 - z2,
            _ => 0.0,
        };
        (dx * dx + dy * dy + dz * dz).sqrt()
    }

    /// Great-circle distance to `other` in kilometres (haversine formula).
    ///
    /// Treats `x` as longitude and `y` as latitude in degrees, on a sphere of
    /// radius 6371 km. The `z` axis is ignored.
    pub fn haversine_distance(&self, other: &SpatialCoordinate) -> f64 {
        const EARTH_RADIUS_KM: f64 = 6371.0;

        let lat1 = self.y.to_radians();
        let lat2 = other.y.to_radians();
        let delta_lat = (other.y - self.y).to_radians();
        let delta_lon = (other.x - self.x).to_radians();

        let a = (delta_lat / 2.0).sin().powi(2)
            + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
        // Floating-point rounding can push `a` marginally outside [0, 1] for
        // (near-)antipodal points, which would turn `(1 - a).sqrt()` into NaN.
        let a = a.clamp(0.0, 1.0);
        let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());

        EARTH_RADIUS_KM * c
    }
}
54
/// Settings controlling how [`SpatialCrossValidator`] builds its folds.
#[derive(Debug, Clone)]
pub struct SpatialValidationConfig {
    /// Number of spatial clusters, and therefore the maximum number of folds.
    pub n_splits: usize,
    /// Training samples closer than this to any test sample are dropped from
    /// that fold. Measured with `distance_method`, so it is in kilometres when
    /// that method is [`DistanceMethod::Haversine`]. DBSCAN clustering also
    /// derives its neighbourhood radius (half of this value) from it.
    pub buffer_distance: f64,
    /// Metric used for clustering and for the buffer check.
    pub distance_method: DistanceMethod,
    /// Algorithm that groups samples into spatial clusters.
    pub clustering_method: SpatialClusteringMethod,
    /// Seed for the randomised parts of clustering (k-means initialisation).
    /// When `None`, a fixed internal seed is used, so results stay reproducible.
    pub random_state: Option<u64>,
    /// Marks the coordinates as geographic (longitude/latitude).
    /// NOTE(review): not read by the splitting code in this file; geographic
    /// distances are selected through `distance_method` instead.
    pub geographic: bool,
}

impl Default for SpatialValidationConfig {
    /// Five k-means folds, Euclidean distance, buffer of 1000 distance units.
    fn default() -> Self {
        Self {
            n_splits: 5,
            buffer_distance: 1000.0,
            distance_method: DistanceMethod::Euclidean,
            clustering_method: SpatialClusteringMethod::KMeans,
            random_state: None,
            geographic: false,
        }
    }
}

/// Distance metric between two [`SpatialCoordinate`]s.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMethod {
    /// Straight-line distance.
    Euclidean,
    /// Great-circle distance in kilometres; `x` = longitude, `y` = latitude (degrees).
    Haversine,
    /// Sum of absolute coordinate differences.
    Manhattan,
    /// Largest absolute coordinate difference.
    Chebyshev,
}

/// Algorithm used to group samples into spatial clusters (one cluster per test fold).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpatialClusteringMethod {
    /// Lloyd's k-means with `n_splits` centroids.
    KMeans,
    /// Regular grid over the x/y bounding box.
    Grid,
    /// Agglomerative merging until `n_splits` clusters remain.
    Hierarchical,
    /// Density-based clustering; radius derived from `buffer_distance`.
    DBSCAN,
}
110
111#[derive(Debug, Clone)]
113pub struct SpatialCrossValidator {
114 config: SpatialValidationConfig,
115}
116
117impl SpatialCrossValidator {
118 pub fn new(config: SpatialValidationConfig) -> Self {
119 Self { config }
120 }
121
122 pub fn split(
124 &self,
125 n_samples: usize,
126 coordinates: &[SpatialCoordinate],
127 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
128 if coordinates.len() != n_samples {
129 return Err(SklearsError::InvalidInput(
130 "Number of coordinates must match number of samples".to_string(),
131 ));
132 }
133
134 let clusters = self.create_spatial_clusters(coordinates)?;
136
137 let splits = self.generate_cluster_splits(&clusters)?;
139
140 let filtered_splits = self.apply_buffer_constraints(&splits, coordinates)?;
142
143 Ok(filtered_splits)
144 }
145
146 fn create_spatial_clusters(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
148 match self.config.clustering_method {
149 SpatialClusteringMethod::KMeans => self.kmeans_clustering(coordinates),
150 SpatialClusteringMethod::Grid => self.grid_clustering(coordinates),
151 SpatialClusteringMethod::Hierarchical => self.hierarchical_clustering(coordinates),
152 SpatialClusteringMethod::DBSCAN => self.dbscan_clustering(coordinates),
153 }
154 }
155
156 fn kmeans_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
158 let n_samples = coordinates.len();
159 let mut clusters = vec![0; n_samples];
160 let mut centroids = Vec::new();
161
162 let mut rng = self.get_rng();
164 for _i in 0..self.config.n_splits {
165 let idx = rng.gen_range(0..n_samples);
166 centroids.push(coordinates[idx]);
167 }
168
169 for _ in 0..100 {
171 let mut new_centroids = vec![SpatialCoordinate::new(0.0, 0.0); self.config.n_splits];
173 let mut cluster_counts = vec![0; self.config.n_splits];
174 let mut changed = false;
175
176 for (i, coord) in coordinates.iter().enumerate() {
178 let mut min_distance = f64::INFINITY;
179 let mut best_cluster = 0;
180
181 for (j, centroid) in centroids.iter().enumerate() {
182 let distance = self.calculate_distance(coord, centroid);
183 if distance < min_distance {
184 min_distance = distance;
185 best_cluster = j;
186 }
187 }
188
189 if clusters[i] != best_cluster {
190 changed = true;
191 clusters[i] = best_cluster;
192 }
193
194 new_centroids[best_cluster].x += coord.x;
196 new_centroids[best_cluster].y += coord.y;
197 if let Some(z) = coord.z {
198 if new_centroids[best_cluster].z.is_none() {
199 new_centroids[best_cluster].z = Some(0.0);
200 }
201 new_centroids[best_cluster].z =
202 Some(new_centroids[best_cluster].z.unwrap() + z);
203 }
204 cluster_counts[best_cluster] += 1;
205 }
206
207 for (i, count) in cluster_counts.iter().enumerate() {
209 if *count > 0 {
210 new_centroids[i].x /= *count as f64;
211 new_centroids[i].y /= *count as f64;
212 if let Some(z) = new_centroids[i].z {
213 new_centroids[i].z = Some(z / *count as f64);
214 }
215 }
216 }
217
218 centroids = new_centroids;
219
220 if !changed {
221 break;
222 }
223 }
224
225 Ok(clusters)
226 }
227
228 fn grid_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
230 let min_x = coordinates
232 .iter()
233 .map(|c| c.x)
234 .fold(f64::INFINITY, f64::min);
235 let max_x = coordinates
236 .iter()
237 .map(|c| c.x)
238 .fold(f64::NEG_INFINITY, f64::max);
239 let min_y = coordinates
240 .iter()
241 .map(|c| c.y)
242 .fold(f64::INFINITY, f64::min);
243 let max_y = coordinates
244 .iter()
245 .map(|c| c.y)
246 .fold(f64::NEG_INFINITY, f64::max);
247
248 let grid_size = (self.config.n_splits as f64).sqrt().ceil() as usize;
250 let x_step = (max_x - min_x) / grid_size as f64;
251 let y_step = (max_y - min_y) / grid_size as f64;
252
253 let mut clusters = Vec::new();
254
255 for coord in coordinates {
256 let x_grid = ((coord.x - min_x) / x_step).floor() as usize;
257 let y_grid = ((coord.y - min_y) / y_step).floor() as usize;
258
259 let x_grid = x_grid.min(grid_size - 1);
260 let y_grid = y_grid.min(grid_size - 1);
261
262 let cluster_id = (y_grid * grid_size + x_grid) % self.config.n_splits;
263 clusters.push(cluster_id);
264 }
265
266 Ok(clusters)
267 }
268
269 fn hierarchical_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
271 let n_samples = coordinates.len();
272
273 let mut distances = vec![vec![0.0; n_samples]; n_samples];
275 for i in 0..n_samples {
276 for j in i + 1..n_samples {
277 let dist = self.calculate_distance(&coordinates[i], &coordinates[j]);
278 distances[i][j] = dist;
279 distances[j][i] = dist;
280 }
281 }
282
283 let mut clusters = (0..n_samples).collect::<Vec<_>>();
285 let mut cluster_map = (0..n_samples).collect::<Vec<_>>();
286
287 while clusters.len() > self.config.n_splits {
289 let mut min_distance = f64::INFINITY;
290 let mut merge_i = 0;
291 let mut merge_j = 0;
292
293 for i in 0..clusters.len() {
294 for j in i + 1..clusters.len() {
295 let dist = distances[clusters[i]][clusters[j]];
296 if dist < min_distance {
297 min_distance = dist;
298 merge_i = i;
299 merge_j = j;
300 }
301 }
302 }
303
304 let cluster_j = clusters.remove(merge_j);
306 let cluster_i = clusters[merge_i];
307
308 for assignment in &mut cluster_map {
310 if *assignment == cluster_j {
311 *assignment = cluster_i;
312 }
313 }
314 }
315
316 let unique_clusters: Vec<_> = cluster_map
318 .iter()
319 .cloned()
320 .collect::<HashSet<_>>()
321 .into_iter()
322 .collect();
323 let mut final_clusters = vec![0; n_samples];
324
325 for (i, &cluster_id) in cluster_map.iter().enumerate() {
326 final_clusters[i] = unique_clusters
327 .iter()
328 .position(|&x| x == cluster_id)
329 .unwrap_or(0);
330 }
331
332 Ok(final_clusters)
333 }
334
335 fn dbscan_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
337 let n_samples = coordinates.len();
338 let eps = self.config.buffer_distance / 2.0;
339 let min_pts = (n_samples / self.config.n_splits).max(2);
340
341 let mut clusters = vec![None; n_samples];
342 let mut visited = vec![false; n_samples];
343 let mut cluster_id = 0;
344
345 for i in 0..n_samples {
346 if visited[i] {
347 continue;
348 }
349
350 visited[i] = true;
351 let neighbors = self.find_neighbors(i, coordinates, eps);
352
353 if neighbors.len() < min_pts {
354 clusters[i] = Some(usize::MAX); } else {
356 self.expand_cluster(
357 i,
358 &neighbors,
359 cluster_id,
360 coordinates,
361 eps,
362 min_pts,
363 &mut clusters,
364 &mut visited,
365 );
366 cluster_id += 1;
367 }
368 }
369
370 let max_clusters = cluster_id.min(self.config.n_splits);
372 let mut final_clusters = vec![0; n_samples];
373
374 for (i, cluster) in clusters.iter().enumerate() {
375 final_clusters[i] = match cluster {
376 Some(id) if *id != usize::MAX => *id % max_clusters,
377 _ => i % self.config.n_splits, };
379 }
380
381 Ok(final_clusters)
382 }
383
384 fn find_neighbors(
385 &self,
386 point: usize,
387 coordinates: &[SpatialCoordinate],
388 eps: f64,
389 ) -> Vec<usize> {
390 let mut neighbors = Vec::new();
391 for (i, coord) in coordinates.iter().enumerate() {
392 if i != point && self.calculate_distance(&coordinates[point], coord) <= eps {
393 neighbors.push(i);
394 }
395 }
396 neighbors
397 }
398
399 #[allow(clippy::too_many_arguments)]
400 fn expand_cluster(
401 &self,
402 point: usize,
403 neighbors: &[usize],
404 cluster_id: usize,
405 coordinates: &[SpatialCoordinate],
406 eps: f64,
407 min_pts: usize,
408 clusters: &mut [Option<usize>],
409 visited: &mut [bool],
410 ) {
411 clusters[point] = Some(cluster_id);
412 let mut seed_set = neighbors.to_vec();
413 let mut i = 0;
414
415 while i < seed_set.len() {
416 let q = seed_set[i];
417
418 if !visited[q] {
419 visited[q] = true;
420 let q_neighbors = self.find_neighbors(q, coordinates, eps);
421
422 if q_neighbors.len() >= min_pts {
423 seed_set.extend(q_neighbors);
424 }
425 }
426
427 if clusters[q].is_none() {
428 clusters[q] = Some(cluster_id);
429 }
430
431 i += 1;
432 }
433 }
434
435 fn generate_cluster_splits(&self, clusters: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
437 let mut splits = Vec::new();
438
439 for test_cluster in 0..self.config.n_splits {
440 let mut train_indices = Vec::new();
441 let mut test_indices = Vec::new();
442
443 for (i, &cluster) in clusters.iter().enumerate() {
444 if cluster == test_cluster {
445 test_indices.push(i);
446 } else {
447 train_indices.push(i);
448 }
449 }
450
451 if !train_indices.is_empty() && !test_indices.is_empty() {
452 splits.push((train_indices, test_indices));
453 }
454 }
455
456 Ok(splits)
457 }
458
459 fn apply_buffer_constraints(
461 &self,
462 splits: &[(Vec<usize>, Vec<usize>)],
463 coordinates: &[SpatialCoordinate],
464 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
465 let mut filtered_splits = Vec::new();
466
467 for (train_indices, test_indices) in splits {
468 let mut filtered_train = Vec::new();
469
470 for &train_idx in train_indices {
471 let mut too_close = false;
472
473 for &test_idx in test_indices {
474 let distance =
475 self.calculate_distance(&coordinates[train_idx], &coordinates[test_idx]);
476
477 if distance < self.config.buffer_distance {
478 too_close = true;
479 break;
480 }
481 }
482
483 if !too_close {
484 filtered_train.push(train_idx);
485 }
486 }
487
488 if !filtered_train.is_empty() && !test_indices.is_empty() {
489 filtered_splits.push((filtered_train, test_indices.clone()));
490 }
491 }
492
493 Ok(filtered_splits)
494 }
495
496 fn calculate_distance(&self, coord1: &SpatialCoordinate, coord2: &SpatialCoordinate) -> f64 {
498 match self.config.distance_method {
499 DistanceMethod::Euclidean => coord1.distance(coord2),
500 DistanceMethod::Haversine => coord1.haversine_distance(coord2),
501 DistanceMethod::Manhattan => (coord1.x - coord2.x).abs() + (coord1.y - coord2.y).abs(),
502 DistanceMethod::Chebyshev => {
503 (coord1.x - coord2.x).abs().max((coord1.y - coord2.y).abs())
504 }
505 }
506 }
507
508 fn get_rng(&self) -> impl scirs2_core::random::Rng {
509 use scirs2_core::random::rngs::StdRng;
510 use scirs2_core::random::SeedableRng;
511 match self.config.random_state {
512 Some(seed) => StdRng::seed_from_u64(seed),
513 None => StdRng::seed_from_u64(42),
514 }
515 }
516}
517
518#[derive(Debug, Clone)]
520pub struct LeaveOneRegionOut {
521 region_labels: Vec<usize>,
522}
523
524impl LeaveOneRegionOut {
525 pub fn new(region_labels: Vec<usize>) -> Self {
526 Self { region_labels }
527 }
528
529 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
531 if self.region_labels.len() != n_samples {
532 return Err(SklearsError::InvalidInput(
533 "Region labels length must match number of samples".to_string(),
534 ));
535 }
536
537 let unique_regions: HashSet<usize> = self.region_labels.iter().cloned().collect();
538 let mut splits = Vec::new();
539
540 for test_region in unique_regions {
541 let mut train_indices = Vec::new();
542 let mut test_indices = Vec::new();
543
544 for (i, ®ion) in self.region_labels.iter().enumerate() {
545 if region == test_region {
546 test_indices.push(i);
547 } else {
548 train_indices.push(i);
549 }
550 }
551
552 if !train_indices.is_empty() && !test_indices.is_empty() {
553 splits.push((train_indices, test_indices));
554 }
555 }
556
557 Ok(splits)
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// `side × side` lattice of unit-spaced 2D points, laid out row by row.
    fn grid_coordinates(side: usize) -> Vec<SpatialCoordinate> {
        (0..side * side)
            .map(|i| SpatialCoordinate::new((i % side) as f64, (i / side) as f64))
            .collect()
    }

    /// Seeded validator using `method` and the given buffer distance.
    fn validator(method: SpatialClusteringMethod, buffer_distance: f64) -> SpatialCrossValidator {
        SpatialCrossValidator::new(SpatialValidationConfig {
            clustering_method: method,
            buffer_distance,
            random_state: Some(7),
            ..Default::default()
        })
    }

    #[test]
    fn test_spatial_coordinate_distance() {
        let coord1 = SpatialCoordinate::new(0.0, 0.0);
        let coord2 = SpatialCoordinate::new(3.0, 4.0);

        assert!((coord1.distance(&coord2) - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_spatial_coordinate_distance_3d() {
        let a = SpatialCoordinate::new_3d(0.0, 0.0, 0.0);
        let b = SpatialCoordinate::new_3d(1.0, 2.0, 2.0);

        assert!((a.distance(&b) - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_haversine_one_degree_of_latitude() {
        let a = SpatialCoordinate::new(0.0, 0.0);
        let b = SpatialCoordinate::new(0.0, 1.0);
        let expected = 6371.0 * 1.0_f64.to_radians();

        assert!((a.haversine_distance(&b) - expected).abs() < 1e-9);
    }

    #[test]
    fn test_spatial_cross_validator() {
        let config = SpatialValidationConfig {
            buffer_distance: 1.0,
            ..Default::default()
        };
        let cv = SpatialCrossValidator::new(config);
        let coordinates = grid_coordinates(5);

        let splits = cv.split(25, &coordinates).unwrap();
        assert!(!splits.is_empty(), "Should generate at least one split");

        for (train_indices, test_indices) in &splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }

    #[test]
    fn test_split_rejects_length_mismatch() {
        let cv = validator(SpatialClusteringMethod::KMeans, 1.0);
        assert!(cv.split(3, &grid_coordinates(1)).is_err());
    }

    #[test]
    fn test_every_clustering_method_partitions_samples() {
        let coordinates = grid_coordinates(5);
        let all: Vec<usize> = (0..coordinates.len()).collect();

        for method in [
            SpatialClusteringMethod::KMeans,
            SpatialClusteringMethod::Grid,
            SpatialClusteringMethod::Hierarchical,
            SpatialClusteringMethod::DBSCAN,
        ] {
            // A zero buffer removes nothing, so every fold must cover all samples.
            let cv = validator(method, 0.0);
            let splits = cv.split(coordinates.len(), &coordinates).unwrap();
            assert!(!splits.is_empty());

            for (train, test) in &splits {
                assert!(!test.is_empty());
                let mut seen: Vec<usize> = train.iter().chain(test).copied().collect();
                seen.sort_unstable();
                assert_eq!(seen, all, "train/test must be a disjoint cover");
            }
        }
    }

    #[test]
    fn test_buffer_keeps_training_away_from_test() {
        let coordinates = grid_coordinates(6);
        let buffer = 1.5;
        let cv = validator(SpatialClusteringMethod::KMeans, buffer);

        for (train, test) in &cv.split(coordinates.len(), &coordinates).unwrap() {
            for &t in train {
                for &s in test {
                    assert!(coordinates[t].distance(&coordinates[s]) >= buffer);
                }
            }
        }
    }

    #[test]
    fn test_kmeans_split_is_reproducible_with_seed() {
        let coordinates = grid_coordinates(5);
        let cv = validator(SpatialClusteringMethod::KMeans, 1.0);

        let first = cv.split(coordinates.len(), &coordinates).unwrap();
        let second = cv.split(coordinates.len(), &coordinates).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_leave_one_region_out() {
        let region_labels = vec![0, 0, 1, 1, 2, 2];
        let cv = LeaveOneRegionOut::new(region_labels);

        let splits = cv.split(6).unwrap();
        assert_eq!(splits.len(), 3, "Should have 3 splits for 3 regions");

        for (train_indices, test_indices) in &splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }

    #[test]
    fn test_leave_one_region_out_rejects_length_mismatch() {
        let cv = LeaveOneRegionOut::new(vec![0, 1, 2]);
        assert!(cv.split(5).is_err());
    }
}