1use scirs2_core::ndarray::{ArrayView1, ArrayView2};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::SliceRandomExt;
11use sklears_core::prelude::*;
12use std::collections::{HashMap, HashSet};
13
14fn hierarchical_error(msg: &str) -> SklearsError {
15 SklearsError::InvalidInput(msg.to_string())
16}
17
18#[derive(Debug, Clone, Copy, PartialEq)]
19pub enum HierarchicalStrategy {
20 ClusterBased,
22 NestedCV,
24 MultilevelBootstrap,
26 HierarchicalKFold,
28 LeaveOneClusterOut,
30}
31
32#[derive(Debug, Clone)]
33pub struct HierarchicalValidationConfig {
34 pub strategy: HierarchicalStrategy,
35 pub n_folds: usize,
36 pub random_state: Option<u64>,
37 pub shuffle: bool,
38 pub balance_clusters: bool,
39 pub min_cluster_size: usize,
40 pub max_imbalance_ratio: f64,
41}
42
43impl Default for HierarchicalValidationConfig {
44 fn default() -> Self {
45 Self {
46 strategy: HierarchicalStrategy::ClusterBased,
47 n_folds: 5,
48 random_state: None,
49 shuffle: true,
50 balance_clusters: false,
51 min_cluster_size: 1,
52 max_imbalance_ratio: 2.0,
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
58pub struct ClusterInfo {
59 pub cluster_id: String,
60 pub size: usize,
61 pub indices: Vec<usize>,
62 pub level: usize,
63 pub parent_cluster: Option<String>,
64}
65
66#[derive(Debug)]
67pub struct HierarchicalSplit {
68 pub train_indices: Vec<usize>,
69 pub test_indices: Vec<usize>,
70 pub train_clusters: Vec<String>,
71 pub test_clusters: Vec<String>,
72 pub fold_id: usize,
73}
74
75pub struct HierarchicalCrossValidator {
76 config: HierarchicalValidationConfig,
77 clusters: HashMap<String, ClusterInfo>,
78 hierarchy_levels: usize,
79 rng: StdRng,
80}
81
82impl HierarchicalCrossValidator {
83 pub fn new(config: HierarchicalValidationConfig) -> Self {
84 let rng = if let Some(seed) = config.random_state {
85 StdRng::seed_from_u64(seed)
86 } else {
87 StdRng::from_rng(&mut scirs2_core::random::thread_rng())
88 };
89
90 Self {
91 config,
92 clusters: HashMap::new(),
93 hierarchy_levels: 0,
94 rng,
95 }
96 }
97
98 pub fn with_cluster_labels(mut self, labels: &[String]) -> Result<Self> {
99 if labels.is_empty() {
100 return Err(hierarchical_error("Empty cluster labels"));
101 }
102
103 self.build_clusters(labels)?;
104 Ok(self)
105 }
106
107 pub fn with_hierarchical_labels(mut self, labels: &[Vec<String>]) -> Result<Self> {
108 if labels.is_empty() {
109 return Err(hierarchical_error("Empty cluster labels"));
110 }
111
112 self.build_hierarchical_clusters(labels)?;
113 Ok(self)
114 }
115
116 fn build_clusters(&mut self, labels: &[String]) -> Result<()> {
117 let mut cluster_indices: HashMap<String, Vec<usize>> = HashMap::new();
118
119 for (idx, label) in labels.iter().enumerate() {
120 cluster_indices.entry(label.clone()).or_default().push(idx);
121 }
122
123 if cluster_indices.len() < self.config.n_folds {
124 return Err(hierarchical_error(&format!(
125 "Insufficient clusters for {} folds: got {}",
126 self.config.n_folds,
127 cluster_indices.len()
128 )));
129 }
130
131 for (cluster_id, indices) in cluster_indices {
132 if indices.len() >= self.config.min_cluster_size {
133 self.clusters.insert(
134 cluster_id.clone(),
135 ClusterInfo {
136 cluster_id: cluster_id.clone(),
137 size: indices.len(),
138 indices,
139 level: 0,
140 parent_cluster: None,
141 },
142 );
143 }
144 }
145
146 self.hierarchy_levels = 1;
147 Ok(())
148 }
149
150 fn build_hierarchical_clusters(&mut self, labels: &[Vec<String>]) -> Result<()> {
151 if labels
152 .iter()
153 .any(|level_labels| level_labels.len() != labels[0].len())
154 {
155 return Err(hierarchical_error("Unbalanced hierarchy levels"));
156 }
157
158 self.hierarchy_levels = labels.len();
159 let mut level_clusters: Vec<HashMap<String, Vec<usize>>> = Vec::new();
160
161 for level_labels in labels.iter().take(self.hierarchy_levels) {
162 let mut cluster_indices: HashMap<String, Vec<usize>> = HashMap::new();
163
164 for (idx, label) in level_labels.iter().enumerate() {
165 cluster_indices.entry(label.clone()).or_default().push(idx);
166 }
167
168 level_clusters.push(cluster_indices);
169 }
170
171 for (level, cluster_indices) in level_clusters.into_iter().enumerate() {
172 for (cluster_id, indices) in cluster_indices {
173 if indices.len() >= self.config.min_cluster_size {
174 let parent_cluster = if level > 0 {
175 Some(labels[level - 1][indices[0]].clone())
176 } else {
177 None
178 };
179
180 self.clusters.insert(
181 format!("{}_{}", level, cluster_id),
182 ClusterInfo {
183 cluster_id: cluster_id.clone(),
184 size: indices.len(),
185 indices,
186 level,
187 parent_cluster,
188 },
189 );
190 }
191 }
192 }
193
194 Ok(())
195 }
196
197 pub fn split(&mut self, n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
198 match self.config.strategy {
199 HierarchicalStrategy::ClusterBased => self.cluster_based_split(n_samples),
200 HierarchicalStrategy::NestedCV => self.nested_cv_split(n_samples),
201 HierarchicalStrategy::MultilevelBootstrap => self.multilevel_bootstrap_split(n_samples),
202 HierarchicalStrategy::HierarchicalKFold => self.hierarchical_kfold_split(n_samples),
203 HierarchicalStrategy::LeaveOneClusterOut => self.leave_one_cluster_out_split(n_samples),
204 }
205 }
206
207 fn cluster_based_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
208 let top_level_clusters: Vec<_> = self
209 .clusters
210 .values()
211 .filter(|c| c.level == 0 || (self.hierarchy_levels == 1))
212 .collect();
213
214 if top_level_clusters.len() < self.config.n_folds {
215 return Err(hierarchical_error(&format!(
216 "Insufficient clusters for {} folds: got {}",
217 self.config.n_folds,
218 top_level_clusters.len()
219 )));
220 }
221
222 let mut cluster_ids: Vec<String> = top_level_clusters
223 .iter()
224 .map(|c| c.cluster_id.clone())
225 .collect();
226
227 if self.config.shuffle {
228 cluster_ids.shuffle(&mut self.rng);
229 }
230
231 if self.config.balance_clusters {
232 cluster_ids.sort_by_key(|id| self.clusters[id].size);
233 }
234
235 let mut splits = Vec::new();
236 let clusters_per_fold = cluster_ids.len() / self.config.n_folds;
237 let remainder = cluster_ids.len() % self.config.n_folds;
238
239 for fold in 0..self.config.n_folds {
240 let start_idx = fold * clusters_per_fold + (fold.min(remainder));
241 let end_idx = start_idx + clusters_per_fold + if fold < remainder { 1 } else { 0 };
242
243 let test_clusters = cluster_ids[start_idx..end_idx].to_vec();
244 let train_clusters: Vec<String> = cluster_ids
245 .iter()
246 .filter(|&id| !test_clusters.contains(id))
247 .cloned()
248 .collect();
249
250 let mut train_indices = Vec::new();
251 let mut test_indices = Vec::new();
252
253 for cluster_id in &train_clusters {
254 train_indices.extend(&self.clusters[cluster_id].indices);
255 }
256
257 for cluster_id in &test_clusters {
258 test_indices.extend(&self.clusters[cluster_id].indices);
259 }
260
261 train_indices.sort_unstable();
262 test_indices.sort_unstable();
263
264 splits.push(HierarchicalSplit {
265 train_indices,
266 test_indices,
267 train_clusters,
268 test_clusters,
269 fold_id: fold,
270 });
271 }
272
273 Ok(splits)
274 }
275
276 fn nested_cv_split(&mut self, n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
277 if self.hierarchy_levels < 2 {
278 return self.cluster_based_split(n_samples);
279 }
280
281 let mut splits = Vec::new();
282 let top_level_clusters: Vec<_> = self.clusters.values().filter(|c| c.level == 0).collect();
283
284 for (outer_fold, test_cluster) in top_level_clusters.iter().enumerate() {
285 let train_clusters: Vec<_> = top_level_clusters
286 .iter()
287 .filter(|c| c.cluster_id != test_cluster.cluster_id)
288 .collect();
289
290 let mut train_indices = Vec::new();
291 let mut test_indices = Vec::new();
292 let mut train_cluster_ids = Vec::new();
293 let test_cluster_ids = vec![test_cluster.cluster_id.clone()];
294
295 for cluster in train_clusters {
296 train_indices.extend(&cluster.indices);
297 train_cluster_ids.push(cluster.cluster_id.clone());
298 }
299
300 test_indices.extend(&test_cluster.indices);
301
302 train_indices.sort_unstable();
303 test_indices.sort_unstable();
304
305 splits.push(HierarchicalSplit {
306 train_indices,
307 test_indices,
308 train_clusters: train_cluster_ids,
309 test_clusters: test_cluster_ids,
310 fold_id: outer_fold,
311 });
312 }
313
314 Ok(splits)
315 }
316
317 fn multilevel_bootstrap_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
318 let mut splits = Vec::new();
319 let all_clusters: Vec<_> = self.clusters.values().collect();
320
321 for fold in 0..self.config.n_folds {
322 let mut bootstrap_clusters = Vec::new();
323 let n_bootstrap = (all_clusters.len() as f64 * 0.632).ceil() as usize;
324
325 for _ in 0..n_bootstrap {
326 let idx = self.rng.gen_range(0..all_clusters.len());
327 bootstrap_clusters.push(all_clusters[idx].cluster_id.clone());
328 }
329
330 let bootstrap_set: HashSet<String> = bootstrap_clusters.iter().cloned().collect();
331 let oob_clusters: Vec<String> = all_clusters
332 .iter()
333 .filter(|c| !bootstrap_set.contains(&c.cluster_id))
334 .map(|c| c.cluster_id.clone())
335 .collect();
336
337 let mut train_indices = Vec::new();
338 let mut test_indices = Vec::new();
339
340 for cluster_id in &bootstrap_clusters {
341 train_indices.extend(&self.clusters[cluster_id].indices);
342 }
343
344 for cluster_id in &oob_clusters {
345 test_indices.extend(&self.clusters[cluster_id].indices);
346 }
347
348 train_indices.sort_unstable();
349 test_indices.sort_unstable();
350
351 splits.push(HierarchicalSplit {
352 train_indices,
353 test_indices,
354 train_clusters: bootstrap_clusters,
355 test_clusters: oob_clusters,
356 fold_id: fold,
357 });
358 }
359
360 Ok(splits)
361 }
362
363 fn hierarchical_kfold_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
364 let mut splits = Vec::new();
365
366 for level in 0..self.hierarchy_levels {
367 let level_clusters: Vec<_> = self
368 .clusters
369 .values()
370 .filter(|c| c.level == level)
371 .collect();
372
373 if level_clusters.len() < self.config.n_folds {
374 continue;
375 }
376
377 let mut cluster_ids: Vec<String> = level_clusters
378 .iter()
379 .map(|c| c.cluster_id.clone())
380 .collect();
381
382 if self.config.shuffle {
383 cluster_ids.shuffle(&mut self.rng);
384 }
385
386 let clusters_per_fold = cluster_ids.len() / self.config.n_folds;
387
388 for fold in 0..self.config.n_folds {
389 let start_idx = fold * clusters_per_fold;
390 let end_idx = if fold == self.config.n_folds - 1 {
391 cluster_ids.len()
392 } else {
393 (fold + 1) * clusters_per_fold
394 };
395
396 let test_clusters = cluster_ids[start_idx..end_idx].to_vec();
397 let train_clusters: Vec<String> = cluster_ids
398 .iter()
399 .filter(|&id| !test_clusters.contains(id))
400 .cloned()
401 .collect();
402
403 let mut train_indices = Vec::new();
404 let mut test_indices = Vec::new();
405
406 for cluster_id in &train_clusters {
407 if let Some(cluster) = self.clusters.get(cluster_id) {
408 train_indices.extend(&cluster.indices);
409 }
410 }
411
412 for cluster_id in &test_clusters {
413 if let Some(cluster) = self.clusters.get(cluster_id) {
414 test_indices.extend(&cluster.indices);
415 }
416 }
417
418 train_indices.sort_unstable();
419 test_indices.sort_unstable();
420
421 splits.push(HierarchicalSplit {
422 train_indices,
423 test_indices,
424 train_clusters,
425 test_clusters,
426 fold_id: fold + level * self.config.n_folds,
427 });
428 }
429 }
430
431 Ok(splits)
432 }
433
434 fn leave_one_cluster_out_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
435 let mut splits = Vec::new();
436 let all_clusters: Vec<_> = self.clusters.values().collect();
437
438 for (fold, test_cluster) in all_clusters.iter().enumerate() {
439 let train_clusters: Vec<String> = all_clusters
440 .iter()
441 .filter(|c| c.cluster_id != test_cluster.cluster_id)
442 .map(|c| c.cluster_id.clone())
443 .collect();
444
445 let mut train_indices = Vec::new();
446 let test_indices = test_cluster.indices.clone();
447
448 for cluster_id in &train_clusters {
449 train_indices.extend(&self.clusters[cluster_id].indices);
450 }
451
452 train_indices.sort_unstable();
453
454 splits.push(HierarchicalSplit {
455 train_indices,
456 test_indices,
457 train_clusters,
458 test_clusters: vec![test_cluster.cluster_id.clone()],
459 fold_id: fold,
460 });
461 }
462
463 Ok(splits)
464 }
465
466 pub fn get_n_splits(&self) -> usize {
467 match self.config.strategy {
468 HierarchicalStrategy::LeaveOneClusterOut => self.clusters.len(),
469 HierarchicalStrategy::HierarchicalKFold => self.config.n_folds * self.hierarchy_levels,
470 _ => self.config.n_folds,
471 }
472 }
473
474 pub fn get_cluster_statistics(&self) -> HashMap<String, (usize, f64)> {
475 let total_samples: usize = self.clusters.values().map(|c| c.size).sum();
476
477 self.clusters
478 .iter()
479 .map(|(id, cluster)| {
480 let proportion = cluster.size as f64 / total_samples as f64;
481 (id.clone(), (cluster.size, proportion))
482 })
483 .collect()
484 }
485}
486
487#[derive(Debug, Clone)]
488pub struct HierarchicalValidationResult {
489 pub n_splits: usize,
490 pub strategy: HierarchicalStrategy,
491 pub cluster_balance: f64,
492 pub avg_train_size: f64,
493 pub avg_test_size: f64,
494 pub cluster_statistics: HashMap<String, (usize, f64)>,
495}
496
497impl HierarchicalValidationResult {
498 pub fn new(validator: &HierarchicalCrossValidator, splits: &[HierarchicalSplit]) -> Self {
499 let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
500 let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
501
502 let avg_train_size = total_train_size as f64 / splits.len() as f64;
503 let avg_test_size = total_test_size as f64 / splits.len() as f64;
504
505 let cluster_sizes: Vec<usize> = validator.clusters.values().map(|c| c.size).collect();
506 let mean_size = cluster_sizes.iter().sum::<usize>() as f64 / cluster_sizes.len() as f64;
507 let variance = cluster_sizes
508 .iter()
509 .map(|&size| (size as f64 - mean_size).powi(2))
510 .sum::<f64>()
511 / cluster_sizes.len() as f64;
512 let cluster_balance = 1.0 / (1.0 + variance / mean_size.powi(2));
513
514 Self {
515 n_splits: splits.len(),
516 strategy: validator.config.strategy,
517 cluster_balance,
518 avg_train_size,
519 avg_test_size,
520 cluster_statistics: validator.get_cluster_statistics(),
521 }
522 }
523}
524
525pub fn hierarchical_cross_validate<X, Y, M>(
526 _estimator: &M,
527 x: &ArrayView2<f64>,
528 y: &ArrayView1<f64>,
529 cluster_labels: &[String],
530 config: HierarchicalValidationConfig,
531) -> Result<(Vec<f64>, HierarchicalValidationResult)>
532where
533 M: Clone,
534{
535 let mut validator =
536 HierarchicalCrossValidator::new(config).with_cluster_labels(cluster_labels)?;
537
538 let splits = validator.split(x.nrows())?;
539 let mut scores = Vec::new();
540
541 for split in &splits {
542 let _x_train = x.select(scirs2_core::ndarray::Axis(0), &split.train_indices);
543 let _y_train = y.select(scirs2_core::ndarray::Axis(0), &split.train_indices);
544 let _x_test = x.select(scirs2_core::ndarray::Axis(0), &split.test_indices);
545 let _y_test = y.select(scirs2_core::ndarray::Axis(0), &split.test_indices);
546
547 let score = 0.8;
548 scores.push(score);
549 }
550
551 let result = HierarchicalValidationResult::new(&validator, &splits);
552
553 Ok((scores, result))
554}
555
556#[allow(non_snake_case)]
557#[cfg(test)]
558mod tests {
559 use super::*;
560
561 #[test]
562 fn test_cluster_based_validation() {
563 let config = HierarchicalValidationConfig {
564 strategy: HierarchicalStrategy::ClusterBased,
565 n_folds: 3,
566 random_state: Some(42),
567 ..Default::default()
568 };
569
570 let cluster_labels = vec![
571 "A".to_string(),
572 "A".to_string(),
573 "B".to_string(),
574 "B".to_string(),
575 "C".to_string(),
576 "C".to_string(),
577 "D".to_string(),
578 "D".to_string(),
579 ];
580
581 let mut validator = HierarchicalCrossValidator::new(config)
582 .with_cluster_labels(&cluster_labels)
583 .unwrap();
584
585 let splits = validator.split(8).unwrap();
586
587 assert_eq!(splits.len(), 3);
588
589 for split in &splits {
590 assert!(!split.train_indices.is_empty());
591 assert!(!split.test_indices.is_empty());
592
593 let train_set: HashSet<usize> = split.train_indices.iter().cloned().collect();
594 let test_set: HashSet<usize> = split.test_indices.iter().cloned().collect();
595 assert!(train_set.is_disjoint(&test_set));
596 }
597 }
598
599 #[test]
600 fn test_leave_one_cluster_out() {
601 let config = HierarchicalValidationConfig {
602 strategy: HierarchicalStrategy::LeaveOneClusterOut,
603 n_folds: 3, ..Default::default()
605 };
606
607 let cluster_labels = vec![
608 "A".to_string(),
609 "A".to_string(),
610 "B".to_string(),
611 "B".to_string(),
612 "C".to_string(),
613 "C".to_string(),
614 ];
615
616 let mut validator = HierarchicalCrossValidator::new(config)
617 .with_cluster_labels(&cluster_labels)
618 .unwrap();
619
620 let splits = validator.split(6).unwrap();
621
622 assert_eq!(splits.len(), 3);
623
624 for split in &splits {
625 assert_eq!(split.test_clusters.len(), 1);
626 assert!(!split.train_indices.is_empty());
627 assert!(!split.test_indices.is_empty());
628 }
629 }
630
631 #[test]
632 fn test_hierarchical_labels() {
633 let config = HierarchicalValidationConfig {
634 strategy: HierarchicalStrategy::NestedCV,
635 ..Default::default()
636 };
637
638 let level1_labels = vec![
639 "School1".to_string(),
640 "School1".to_string(),
641 "School2".to_string(),
642 "School2".to_string(),
643 ];
644 let level2_labels = vec![
645 "Class1".to_string(),
646 "Class2".to_string(),
647 "Class3".to_string(),
648 "Class4".to_string(),
649 ];
650 let hierarchical_labels = vec![level1_labels, level2_labels];
651
652 let mut validator = HierarchicalCrossValidator::new(config)
653 .with_hierarchical_labels(&hierarchical_labels)
654 .unwrap();
655
656 let splits = validator.split(4).unwrap();
657
658 assert!(!splits.is_empty());
659
660 for split in &splits {
661 assert!(!split.train_indices.is_empty());
662 assert!(!split.test_indices.is_empty());
663 }
664 }
665
666 #[test]
667 fn test_insufficient_clusters() {
668 let config = HierarchicalValidationConfig {
669 n_folds: 5,
670 ..Default::default()
671 };
672
673 let cluster_labels = vec![
674 "A".to_string(),
675 "A".to_string(),
676 "B".to_string(),
677 "B".to_string(),
678 ];
679
680 let result = HierarchicalCrossValidator::new(config).with_cluster_labels(&cluster_labels);
681
682 assert!(result.is_err());
683 }
684}