1use scirs2_core::RngExt;
7use sklears_core::error::{Result, SklearsError};
8use std::collections::HashSet;
9
/// A point with planar `x`/`y` coordinates and an optional third (`z`) component.
///
/// When compared with the haversine metric, `x` is interpreted as longitude and
/// `y` as latitude, both in degrees.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpatialCoordinate {
    /// Horizontal coordinate (longitude in degrees for haversine distances).
    pub x: f64,
    /// Vertical coordinate (latitude in degrees for haversine distances).
    pub y: f64,
    /// Optional third dimension; `None` for 2-D points.
    pub z: Option<f64>,
}
17
18impl SpatialCoordinate {
19 pub fn new(x: f64, y: f64) -> Self {
20 Self { x, y, z: None }
21 }
22
23 pub fn new_3d(x: f64, y: f64, z: f64) -> Self {
24 Self { x, y, z: Some(z) }
25 }
26
27 pub fn distance(&self, other: &SpatialCoordinate) -> f64 {
29 let dx = self.x - other.x;
30 let dy = self.y - other.y;
31 let dz = match (self.z, other.z) {
32 (Some(z1), Some(z2)) => z1 - z2,
33 _ => 0.0,
34 };
35 (dx * dx + dy * dy + dz * dz).sqrt()
36 }
37
38 pub fn haversine_distance(&self, other: &SpatialCoordinate) -> f64 {
40 const EARTH_RADIUS_KM: f64 = 6371.0;
41
42 let lat1 = self.y.to_radians();
43 let lat2 = other.y.to_radians();
44 let delta_lat = (other.y - self.y).to_radians();
45 let delta_lon = (other.x - self.x).to_radians();
46
47 let a = (delta_lat / 2.0).sin().powi(2)
48 + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
49 let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
50
51 EARTH_RADIUS_KM * c
52 }
53}
54
/// Configuration for `SpatialCrossValidator`.
#[derive(Debug, Clone)]
pub struct SpatialValidationConfig {
    /// Number of spatial clusters / folds to build.
    pub n_splits: usize,
    /// Training samples strictly closer than this to any test sample are removed
    /// from that fold's training set. Half of this value is also used as the
    /// DBSCAN neighbourhood radius. Units follow `distance_method`
    /// (kilometres for haversine).
    pub buffer_distance: f64,
    /// Metric used by the distance-based clustering methods and the buffer check.
    pub distance_method: DistanceMethod,
    /// Algorithm used to group samples into spatial clusters.
    pub clustering_method: SpatialClusteringMethod,
    /// Seed for the k-means initialisation; `None` falls back to a fixed seed
    /// (42), so results are reproducible either way.
    pub random_state: Option<u64>,
    /// Presumably marks coordinates as geographic (longitude/latitude).
    /// NOTE(review): not read by any code in this file — confirm whether it is
    /// meant to select `DistanceMethod::Haversine`.
    pub geographic: bool,
}
71
72impl Default for SpatialValidationConfig {
73 fn default() -> Self {
74 Self {
75 n_splits: 5,
76 buffer_distance: 1000.0, distance_method: DistanceMethod::Euclidean,
78 clustering_method: SpatialClusteringMethod::KMeans,
79 random_state: None,
80 geographic: false,
81 }
82 }
83}
84
/// Metric used to compare two spatial coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMethod {
    /// Straight-line (Euclidean) distance.
    Euclidean,
    /// Great-circle distance in kilometres on a spherical Earth, with `x` as
    /// longitude and `y` as latitude in degrees.
    Haversine,
    /// Sum of absolute coordinate differences (L1 / taxicab distance).
    Manhattan,
    /// Largest absolute coordinate difference (L-infinity distance).
    Chebyshev,
}
97
/// Algorithm used to group samples into spatial clusters (one cluster per fold).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpatialClusteringMethod {
    /// k-means clustering seeded from `random_state`.
    KMeans,
    /// Regular grid over the x/y bounding box; grid cells are mapped onto
    /// `n_splits` folds.
    Grid,
    /// Agglomerative clustering that merges the closest clusters until
    /// `n_splits` remain.
    Hierarchical,
    /// Density-based clustering (DBSCAN) whose neighbourhood radius is derived
    /// from `buffer_distance`.
    DBSCAN,
}
110
/// Cross-validator that groups samples into spatial clusters, uses each cluster
/// in turn as the test fold, and drops training samples lying within
/// `buffer_distance` of that fold's test samples.
#[derive(Debug, Clone)]
pub struct SpatialCrossValidator {
    /// Clustering, distance, buffer and seeding settings.
    config: SpatialValidationConfig,
}
116
117impl SpatialCrossValidator {
118 pub fn new(config: SpatialValidationConfig) -> Self {
119 Self { config }
120 }
121
122 pub fn split(
124 &self,
125 n_samples: usize,
126 coordinates: &[SpatialCoordinate],
127 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
128 if coordinates.len() != n_samples {
129 return Err(SklearsError::InvalidInput(
130 "Number of coordinates must match number of samples".to_string(),
131 ));
132 }
133
134 let clusters = self.create_spatial_clusters(coordinates)?;
136
137 let splits = self.generate_cluster_splits(&clusters)?;
139
140 let filtered_splits = self.apply_buffer_constraints(&splits, coordinates)?;
142
143 Ok(filtered_splits)
144 }
145
146 fn create_spatial_clusters(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
148 match self.config.clustering_method {
149 SpatialClusteringMethod::KMeans => self.kmeans_clustering(coordinates),
150 SpatialClusteringMethod::Grid => self.grid_clustering(coordinates),
151 SpatialClusteringMethod::Hierarchical => self.hierarchical_clustering(coordinates),
152 SpatialClusteringMethod::DBSCAN => self.dbscan_clustering(coordinates),
153 }
154 }
155
156 fn kmeans_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
158 let n_samples = coordinates.len();
159 let mut clusters = vec![0; n_samples];
160 let mut centroids = Vec::new();
161
162 let mut rng = self.get_rng();
164 for _i in 0..self.config.n_splits {
165 let idx = rng.random_range(0..n_samples);
166 centroids.push(coordinates[idx]);
167 }
168
169 for _ in 0..100 {
171 let mut new_centroids = vec![SpatialCoordinate::new(0.0, 0.0); self.config.n_splits];
173 let mut cluster_counts = vec![0; self.config.n_splits];
174 let mut changed = false;
175
176 for (i, coord) in coordinates.iter().enumerate() {
178 let mut min_distance = f64::INFINITY;
179 let mut best_cluster = 0;
180
181 for (j, centroid) in centroids.iter().enumerate() {
182 let distance = self.calculate_distance(coord, centroid);
183 if distance < min_distance {
184 min_distance = distance;
185 best_cluster = j;
186 }
187 }
188
189 if clusters[i] != best_cluster {
190 changed = true;
191 clusters[i] = best_cluster;
192 }
193
194 new_centroids[best_cluster].x += coord.x;
196 new_centroids[best_cluster].y += coord.y;
197 if let Some(z) = coord.z {
198 if new_centroids[best_cluster].z.is_none() {
199 new_centroids[best_cluster].z = Some(0.0);
200 }
201 new_centroids[best_cluster].z = Some(
202 new_centroids[best_cluster]
203 .z
204 .expect("operation should succeed")
205 + z,
206 );
207 }
208 cluster_counts[best_cluster] += 1;
209 }
210
211 for (i, count) in cluster_counts.iter().enumerate() {
213 if *count > 0 {
214 new_centroids[i].x /= *count as f64;
215 new_centroids[i].y /= *count as f64;
216 if let Some(z) = new_centroids[i].z {
217 new_centroids[i].z = Some(z / *count as f64);
218 }
219 }
220 }
221
222 centroids = new_centroids;
223
224 if !changed {
225 break;
226 }
227 }
228
229 Ok(clusters)
230 }
231
232 fn grid_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
234 let min_x = coordinates
236 .iter()
237 .map(|c| c.x)
238 .fold(f64::INFINITY, f64::min);
239 let max_x = coordinates
240 .iter()
241 .map(|c| c.x)
242 .fold(f64::NEG_INFINITY, f64::max);
243 let min_y = coordinates
244 .iter()
245 .map(|c| c.y)
246 .fold(f64::INFINITY, f64::min);
247 let max_y = coordinates
248 .iter()
249 .map(|c| c.y)
250 .fold(f64::NEG_INFINITY, f64::max);
251
252 let grid_size = (self.config.n_splits as f64).sqrt().ceil() as usize;
254 let x_step = (max_x - min_x) / grid_size as f64;
255 let y_step = (max_y - min_y) / grid_size as f64;
256
257 let mut clusters = Vec::new();
258
259 for coord in coordinates {
260 let x_grid = ((coord.x - min_x) / x_step).floor() as usize;
261 let y_grid = ((coord.y - min_y) / y_step).floor() as usize;
262
263 let x_grid = x_grid.min(grid_size - 1);
264 let y_grid = y_grid.min(grid_size - 1);
265
266 let cluster_id = (y_grid * grid_size + x_grid) % self.config.n_splits;
267 clusters.push(cluster_id);
268 }
269
270 Ok(clusters)
271 }
272
273 fn hierarchical_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
275 let n_samples = coordinates.len();
276
277 let mut distances = vec![vec![0.0; n_samples]; n_samples];
279 for i in 0..n_samples {
280 for j in i + 1..n_samples {
281 let dist = self.calculate_distance(&coordinates[i], &coordinates[j]);
282 distances[i][j] = dist;
283 distances[j][i] = dist;
284 }
285 }
286
287 let mut clusters = (0..n_samples).collect::<Vec<_>>();
289 let mut cluster_map = (0..n_samples).collect::<Vec<_>>();
290
291 while clusters.len() > self.config.n_splits {
293 let mut min_distance = f64::INFINITY;
294 let mut merge_i = 0;
295 let mut merge_j = 0;
296
297 for i in 0..clusters.len() {
298 for j in i + 1..clusters.len() {
299 let dist = distances[clusters[i]][clusters[j]];
300 if dist < min_distance {
301 min_distance = dist;
302 merge_i = i;
303 merge_j = j;
304 }
305 }
306 }
307
308 let cluster_j = clusters.remove(merge_j);
310 let cluster_i = clusters[merge_i];
311
312 for assignment in &mut cluster_map {
314 if *assignment == cluster_j {
315 *assignment = cluster_i;
316 }
317 }
318 }
319
320 let unique_clusters: Vec<_> = cluster_map
322 .iter()
323 .cloned()
324 .collect::<HashSet<_>>()
325 .into_iter()
326 .collect();
327 let mut final_clusters = vec![0; n_samples];
328
329 for (i, &cluster_id) in cluster_map.iter().enumerate() {
330 final_clusters[i] = unique_clusters
331 .iter()
332 .position(|&x| x == cluster_id)
333 .unwrap_or(0);
334 }
335
336 Ok(final_clusters)
337 }
338
339 fn dbscan_clustering(&self, coordinates: &[SpatialCoordinate]) -> Result<Vec<usize>> {
341 let n_samples = coordinates.len();
342 let eps = self.config.buffer_distance / 2.0;
343 let min_pts = (n_samples / self.config.n_splits).max(2);
344
345 let mut clusters = vec![None; n_samples];
346 let mut visited = vec![false; n_samples];
347 let mut cluster_id = 0;
348
349 for i in 0..n_samples {
350 if visited[i] {
351 continue;
352 }
353
354 visited[i] = true;
355 let neighbors = self.find_neighbors(i, coordinates, eps);
356
357 if neighbors.len() < min_pts {
358 clusters[i] = Some(usize::MAX); } else {
360 self.expand_cluster(
361 i,
362 &neighbors,
363 cluster_id,
364 coordinates,
365 eps,
366 min_pts,
367 &mut clusters,
368 &mut visited,
369 );
370 cluster_id += 1;
371 }
372 }
373
374 let max_clusters = cluster_id.min(self.config.n_splits);
376 let mut final_clusters = vec![0; n_samples];
377
378 for (i, cluster) in clusters.iter().enumerate() {
379 final_clusters[i] = match cluster {
380 Some(id) if *id != usize::MAX => *id % max_clusters,
381 _ => i % self.config.n_splits, };
383 }
384
385 Ok(final_clusters)
386 }
387
388 fn find_neighbors(
389 &self,
390 point: usize,
391 coordinates: &[SpatialCoordinate],
392 eps: f64,
393 ) -> Vec<usize> {
394 let mut neighbors = Vec::new();
395 for (i, coord) in coordinates.iter().enumerate() {
396 if i != point && self.calculate_distance(&coordinates[point], coord) <= eps {
397 neighbors.push(i);
398 }
399 }
400 neighbors
401 }
402
403 #[allow(clippy::too_many_arguments)]
404 fn expand_cluster(
405 &self,
406 point: usize,
407 neighbors: &[usize],
408 cluster_id: usize,
409 coordinates: &[SpatialCoordinate],
410 eps: f64,
411 min_pts: usize,
412 clusters: &mut [Option<usize>],
413 visited: &mut [bool],
414 ) {
415 clusters[point] = Some(cluster_id);
416 let mut seed_set = neighbors.to_vec();
417 let mut i = 0;
418
419 while i < seed_set.len() {
420 let q = seed_set[i];
421
422 if !visited[q] {
423 visited[q] = true;
424 let q_neighbors = self.find_neighbors(q, coordinates, eps);
425
426 if q_neighbors.len() >= min_pts {
427 seed_set.extend(q_neighbors);
428 }
429 }
430
431 if clusters[q].is_none() {
432 clusters[q] = Some(cluster_id);
433 }
434
435 i += 1;
436 }
437 }
438
439 fn generate_cluster_splits(&self, clusters: &[usize]) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
441 let mut splits = Vec::new();
442
443 for test_cluster in 0..self.config.n_splits {
444 let mut train_indices = Vec::new();
445 let mut test_indices = Vec::new();
446
447 for (i, &cluster) in clusters.iter().enumerate() {
448 if cluster == test_cluster {
449 test_indices.push(i);
450 } else {
451 train_indices.push(i);
452 }
453 }
454
455 if !train_indices.is_empty() && !test_indices.is_empty() {
456 splits.push((train_indices, test_indices));
457 }
458 }
459
460 Ok(splits)
461 }
462
463 fn apply_buffer_constraints(
465 &self,
466 splits: &[(Vec<usize>, Vec<usize>)],
467 coordinates: &[SpatialCoordinate],
468 ) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
469 let mut filtered_splits = Vec::new();
470
471 for (train_indices, test_indices) in splits {
472 let mut filtered_train = Vec::new();
473
474 for &train_idx in train_indices {
475 let mut too_close = false;
476
477 for &test_idx in test_indices {
478 let distance =
479 self.calculate_distance(&coordinates[train_idx], &coordinates[test_idx]);
480
481 if distance < self.config.buffer_distance {
482 too_close = true;
483 break;
484 }
485 }
486
487 if !too_close {
488 filtered_train.push(train_idx);
489 }
490 }
491
492 if !filtered_train.is_empty() && !test_indices.is_empty() {
493 filtered_splits.push((filtered_train, test_indices.clone()));
494 }
495 }
496
497 Ok(filtered_splits)
498 }
499
500 fn calculate_distance(&self, coord1: &SpatialCoordinate, coord2: &SpatialCoordinate) -> f64 {
502 match self.config.distance_method {
503 DistanceMethod::Euclidean => coord1.distance(coord2),
504 DistanceMethod::Haversine => coord1.haversine_distance(coord2),
505 DistanceMethod::Manhattan => (coord1.x - coord2.x).abs() + (coord1.y - coord2.y).abs(),
506 DistanceMethod::Chebyshev => {
507 (coord1.x - coord2.x).abs().max((coord1.y - coord2.y).abs())
508 }
509 }
510 }
511
512 fn get_rng(&self) -> impl scirs2_core::random::Rng {
513 use scirs2_core::random::rngs::StdRng;
514 use scirs2_core::random::SeedableRng;
515 match self.config.random_state {
516 Some(seed) => StdRng::seed_from_u64(seed),
517 None => StdRng::seed_from_u64(42),
518 }
519 }
520}
521
/// Leave-one-region-out splitter: each distinct label in `region_labels` becomes
/// one test fold, and every other sample is used for training.
#[derive(Debug, Clone)]
pub struct LeaveOneRegionOut {
    /// Region label for each sample, indexed by sample position.
    region_labels: Vec<usize>,
}
527
528impl LeaveOneRegionOut {
529 pub fn new(region_labels: Vec<usize>) -> Self {
530 Self { region_labels }
531 }
532
533 pub fn split(&self, n_samples: usize) -> Result<Vec<(Vec<usize>, Vec<usize>)>> {
535 if self.region_labels.len() != n_samples {
536 return Err(SklearsError::InvalidInput(
537 "Region labels length must match number of samples".to_string(),
538 ));
539 }
540
541 let unique_regions: HashSet<usize> = self.region_labels.iter().cloned().collect();
542 let mut splits = Vec::new();
543
544 for test_region in unique_regions {
545 let mut train_indices = Vec::new();
546 let mut test_indices = Vec::new();
547
548 for (i, ®ion) in self.region_labels.iter().enumerate() {
549 if region == test_region {
550 test_indices.push(i);
551 } else {
552 train_indices.push(i);
553 }
554 }
555
556 if !train_indices.is_empty() && !test_indices.is_empty() {
557 splits.push((train_indices, test_indices));
558 }
559 }
560
561 Ok(splits)
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every `(train, test)` fold has both sides populated.
    fn assert_folds_non_empty(splits: &[(Vec<usize>, Vec<usize>)]) {
        for (train_indices, test_indices) in splits {
            assert!(
                !train_indices.is_empty(),
                "Training set should not be empty"
            );
            assert!(!test_indices.is_empty(), "Test set should not be empty");
        }
    }

    #[test]
    fn test_spatial_coordinate_distance() {
        let origin = SpatialCoordinate::new(0.0, 0.0);
        let target = SpatialCoordinate::new(3.0, 4.0);

        let error = (origin.distance(&target) - 5.0).abs();
        assert!(error < 1e-10);
    }

    #[test]
    fn test_spatial_cross_validator() {
        let cv = SpatialCrossValidator::new(SpatialValidationConfig {
            buffer_distance: 1.0,
            ..Default::default()
        });

        // 5x5 lattice with unit spacing, laid out row by row.
        let coordinates: Vec<SpatialCoordinate> = (0..25)
            .map(|i| SpatialCoordinate::new((i % 5) as f64, (i / 5) as f64))
            .collect();

        let splits = cv
            .split(25, &coordinates)
            .expect("operation should succeed");
        assert!(!splits.is_empty(), "Should generate at least one split");
        assert_folds_non_empty(&splits);
    }

    #[test]
    fn test_leave_one_region_out() {
        let cv = LeaveOneRegionOut::new(vec![0, 0, 1, 1, 2, 2]);

        let splits = cv.split(6).expect("operation should succeed");
        assert_eq!(splits.len(), 3, "Should have 3 splits for 3 regions");
        assert_folds_non_empty(&splits);
    }
}