sklears_model_selection/
hierarchical_validation.rs

1//! Hierarchical cross-validation for clustered data
2//!
3//! This module provides cross-validation strategies for hierarchical or clustered data,
4//! where observations are nested within groups (e.g., students within schools,
5//! patients within hospitals, measurements within subjects).
6
7use scirs2_core::ndarray::{ArrayView1, ArrayView2};
8use scirs2_core::random::prelude::*;
9use scirs2_core::random::rngs::StdRng;
10use scirs2_core::SliceRandomExt;
11use sklears_core::prelude::*;
12use std::collections::{HashMap, HashSet};
13
14fn hierarchical_error(msg: &str) -> SklearsError {
15    SklearsError::InvalidInput(msg.to_string())
16}
17
/// Splitting strategy selected through [`HierarchicalValidationConfig::strategy`].
///
/// The enum carries no data, so it is `Copy` and can be compared, hashed and
/// used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HierarchicalStrategy {
    /// K-fold over whole clusters: every cluster lands in the test set of exactly one fold.
    ClusterBased,
    /// With two or more hierarchy levels, holds out each level-0 cluster in turn;
    /// with a single level it behaves like [`HierarchicalStrategy::ClusterBased`].
    NestedCV,
    /// Draws clusters with replacement for training and tests on the clusters
    /// that were never drawn (out-of-bag).
    MultilevelBootstrap,
    /// Runs a cluster k-fold separately on every hierarchy level that has at
    /// least `n_folds` clusters.
    HierarchicalKFold,
    /// Produces one split per cluster, with that cluster as the test set.
    LeaveOneClusterOut,
}
31
32#[derive(Debug, Clone)]
33pub struct HierarchicalValidationConfig {
34    pub strategy: HierarchicalStrategy,
35    pub n_folds: usize,
36    pub random_state: Option<u64>,
37    pub shuffle: bool,
38    pub balance_clusters: bool,
39    pub min_cluster_size: usize,
40    pub max_imbalance_ratio: f64,
41}
42
43impl Default for HierarchicalValidationConfig {
44    fn default() -> Self {
45        Self {
46            strategy: HierarchicalStrategy::ClusterBased,
47            n_folds: 5,
48            random_state: None,
49            shuffle: true,
50            balance_clusters: false,
51            min_cluster_size: 1,
52            max_imbalance_ratio: 2.0,
53        }
54    }
55}
56
/// One cluster (group of observations) at one level of the hierarchy.
#[derive(Debug, Clone)]
pub struct ClusterInfo {
    /// The label exactly as supplied by the caller (no level prefix).
    pub cluster_id: String,
    /// Number of observations in the cluster; the builders set it to `indices.len()`.
    pub size: usize,
    /// Positions of the cluster's observations in the label slice, in ascending order.
    pub indices: Vec<usize>,
    /// Hierarchy level the cluster belongs to; flat labels always use level 0.
    pub level: usize,
    /// Label that the cluster's first observation carries on the previous
    /// level, or `None` for level-0 clusters.
    pub parent_cluster: Option<String>,
}
65
/// A single train/test partition produced by [`HierarchicalCrossValidator::split`].
#[derive(Debug)]
pub struct HierarchicalSplit {
    /// Observation indices used for training.
    pub train_indices: Vec<usize>,
    /// Observation indices used for evaluation.
    pub test_indices: Vec<usize>,
    /// Labels of the clusters that contributed to `train_indices`.
    pub train_clusters: Vec<String>,
    /// Labels of the clusters that contributed to `test_indices`.
    pub test_clusters: Vec<String>,
    /// Identifier assigned by the strategy that generated the split.
    pub fold_id: usize,
}
74
75pub struct HierarchicalCrossValidator {
76    config: HierarchicalValidationConfig,
77    clusters: HashMap<String, ClusterInfo>,
78    hierarchy_levels: usize,
79    rng: StdRng,
80}
81
82impl HierarchicalCrossValidator {
83    pub fn new(config: HierarchicalValidationConfig) -> Self {
84        let rng = if let Some(seed) = config.random_state {
85            StdRng::seed_from_u64(seed)
86        } else {
87            StdRng::from_rng(&mut scirs2_core::random::thread_rng())
88        };
89
90        Self {
91            config,
92            clusters: HashMap::new(),
93            hierarchy_levels: 0,
94            rng,
95        }
96    }
97
98    pub fn with_cluster_labels(mut self, labels: &[String]) -> Result<Self> {
99        if labels.is_empty() {
100            return Err(hierarchical_error("Empty cluster labels"));
101        }
102
103        self.build_clusters(labels)?;
104        Ok(self)
105    }
106
107    pub fn with_hierarchical_labels(mut self, labels: &[Vec<String>]) -> Result<Self> {
108        if labels.is_empty() {
109            return Err(hierarchical_error("Empty cluster labels"));
110        }
111
112        self.build_hierarchical_clusters(labels)?;
113        Ok(self)
114    }
115
116    fn build_clusters(&mut self, labels: &[String]) -> Result<()> {
117        let mut cluster_indices: HashMap<String, Vec<usize>> = HashMap::new();
118
119        for (idx, label) in labels.iter().enumerate() {
120            cluster_indices.entry(label.clone()).or_default().push(idx);
121        }
122
123        if cluster_indices.len() < self.config.n_folds {
124            return Err(hierarchical_error(&format!(
125                "Insufficient clusters for {} folds: got {}",
126                self.config.n_folds,
127                cluster_indices.len()
128            )));
129        }
130
131        for (cluster_id, indices) in cluster_indices {
132            if indices.len() >= self.config.min_cluster_size {
133                self.clusters.insert(
134                    cluster_id.clone(),
135                    ClusterInfo {
136                        cluster_id: cluster_id.clone(),
137                        size: indices.len(),
138                        indices,
139                        level: 0,
140                        parent_cluster: None,
141                    },
142                );
143            }
144        }
145
146        self.hierarchy_levels = 1;
147        Ok(())
148    }
149
150    fn build_hierarchical_clusters(&mut self, labels: &[Vec<String>]) -> Result<()> {
151        if labels
152            .iter()
153            .any(|level_labels| level_labels.len() != labels[0].len())
154        {
155            return Err(hierarchical_error("Unbalanced hierarchy levels"));
156        }
157
158        self.hierarchy_levels = labels.len();
159        let mut level_clusters: Vec<HashMap<String, Vec<usize>>> = Vec::new();
160
161        for level_labels in labels.iter().take(self.hierarchy_levels) {
162            let mut cluster_indices: HashMap<String, Vec<usize>> = HashMap::new();
163
164            for (idx, label) in level_labels.iter().enumerate() {
165                cluster_indices.entry(label.clone()).or_default().push(idx);
166            }
167
168            level_clusters.push(cluster_indices);
169        }
170
171        for (level, cluster_indices) in level_clusters.into_iter().enumerate() {
172            for (cluster_id, indices) in cluster_indices {
173                if indices.len() >= self.config.min_cluster_size {
174                    let parent_cluster = if level > 0 {
175                        Some(labels[level - 1][indices[0]].clone())
176                    } else {
177                        None
178                    };
179
180                    self.clusters.insert(
181                        format!("{}_{}", level, cluster_id),
182                        ClusterInfo {
183                            cluster_id: cluster_id.clone(),
184                            size: indices.len(),
185                            indices,
186                            level,
187                            parent_cluster,
188                        },
189                    );
190                }
191            }
192        }
193
194        Ok(())
195    }
196
197    pub fn split(&mut self, n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
198        match self.config.strategy {
199            HierarchicalStrategy::ClusterBased => self.cluster_based_split(n_samples),
200            HierarchicalStrategy::NestedCV => self.nested_cv_split(n_samples),
201            HierarchicalStrategy::MultilevelBootstrap => self.multilevel_bootstrap_split(n_samples),
202            HierarchicalStrategy::HierarchicalKFold => self.hierarchical_kfold_split(n_samples),
203            HierarchicalStrategy::LeaveOneClusterOut => self.leave_one_cluster_out_split(n_samples),
204        }
205    }
206
207    fn cluster_based_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
208        let top_level_clusters: Vec<_> = self
209            .clusters
210            .values()
211            .filter(|c| c.level == 0 || (self.hierarchy_levels == 1))
212            .collect();
213
214        if top_level_clusters.len() < self.config.n_folds {
215            return Err(hierarchical_error(&format!(
216                "Insufficient clusters for {} folds: got {}",
217                self.config.n_folds,
218                top_level_clusters.len()
219            )));
220        }
221
222        let mut cluster_ids: Vec<String> = top_level_clusters
223            .iter()
224            .map(|c| c.cluster_id.clone())
225            .collect();
226
227        if self.config.shuffle {
228            cluster_ids.shuffle(&mut self.rng);
229        }
230
231        if self.config.balance_clusters {
232            cluster_ids.sort_by_key(|id| self.clusters[id].size);
233        }
234
235        let mut splits = Vec::new();
236        let clusters_per_fold = cluster_ids.len() / self.config.n_folds;
237        let remainder = cluster_ids.len() % self.config.n_folds;
238
239        for fold in 0..self.config.n_folds {
240            let start_idx = fold * clusters_per_fold + (fold.min(remainder));
241            let end_idx = start_idx + clusters_per_fold + if fold < remainder { 1 } else { 0 };
242
243            let test_clusters = cluster_ids[start_idx..end_idx].to_vec();
244            let train_clusters: Vec<String> = cluster_ids
245                .iter()
246                .filter(|&id| !test_clusters.contains(id))
247                .cloned()
248                .collect();
249
250            let mut train_indices = Vec::new();
251            let mut test_indices = Vec::new();
252
253            for cluster_id in &train_clusters {
254                train_indices.extend(&self.clusters[cluster_id].indices);
255            }
256
257            for cluster_id in &test_clusters {
258                test_indices.extend(&self.clusters[cluster_id].indices);
259            }
260
261            train_indices.sort_unstable();
262            test_indices.sort_unstable();
263
264            splits.push(HierarchicalSplit {
265                train_indices,
266                test_indices,
267                train_clusters,
268                test_clusters,
269                fold_id: fold,
270            });
271        }
272
273        Ok(splits)
274    }
275
276    fn nested_cv_split(&mut self, n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
277        if self.hierarchy_levels < 2 {
278            return self.cluster_based_split(n_samples);
279        }
280
281        let mut splits = Vec::new();
282        let top_level_clusters: Vec<_> = self.clusters.values().filter(|c| c.level == 0).collect();
283
284        for (outer_fold, test_cluster) in top_level_clusters.iter().enumerate() {
285            let train_clusters: Vec<_> = top_level_clusters
286                .iter()
287                .filter(|c| c.cluster_id != test_cluster.cluster_id)
288                .collect();
289
290            let mut train_indices = Vec::new();
291            let mut test_indices = Vec::new();
292            let mut train_cluster_ids = Vec::new();
293            let test_cluster_ids = vec![test_cluster.cluster_id.clone()];
294
295            for cluster in train_clusters {
296                train_indices.extend(&cluster.indices);
297                train_cluster_ids.push(cluster.cluster_id.clone());
298            }
299
300            test_indices.extend(&test_cluster.indices);
301
302            train_indices.sort_unstable();
303            test_indices.sort_unstable();
304
305            splits.push(HierarchicalSplit {
306                train_indices,
307                test_indices,
308                train_clusters: train_cluster_ids,
309                test_clusters: test_cluster_ids,
310                fold_id: outer_fold,
311            });
312        }
313
314        Ok(splits)
315    }
316
317    fn multilevel_bootstrap_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
318        let mut splits = Vec::new();
319        let all_clusters: Vec<_> = self.clusters.values().collect();
320
321        for fold in 0..self.config.n_folds {
322            let mut bootstrap_clusters = Vec::new();
323            let n_bootstrap = (all_clusters.len() as f64 * 0.632).ceil() as usize;
324
325            for _ in 0..n_bootstrap {
326                let idx = self.rng.gen_range(0..all_clusters.len());
327                bootstrap_clusters.push(all_clusters[idx].cluster_id.clone());
328            }
329
330            let bootstrap_set: HashSet<String> = bootstrap_clusters.iter().cloned().collect();
331            let oob_clusters: Vec<String> = all_clusters
332                .iter()
333                .filter(|c| !bootstrap_set.contains(&c.cluster_id))
334                .map(|c| c.cluster_id.clone())
335                .collect();
336
337            let mut train_indices = Vec::new();
338            let mut test_indices = Vec::new();
339
340            for cluster_id in &bootstrap_clusters {
341                train_indices.extend(&self.clusters[cluster_id].indices);
342            }
343
344            for cluster_id in &oob_clusters {
345                test_indices.extend(&self.clusters[cluster_id].indices);
346            }
347
348            train_indices.sort_unstable();
349            test_indices.sort_unstable();
350
351            splits.push(HierarchicalSplit {
352                train_indices,
353                test_indices,
354                train_clusters: bootstrap_clusters,
355                test_clusters: oob_clusters,
356                fold_id: fold,
357            });
358        }
359
360        Ok(splits)
361    }
362
363    fn hierarchical_kfold_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
364        let mut splits = Vec::new();
365
366        for level in 0..self.hierarchy_levels {
367            let level_clusters: Vec<_> = self
368                .clusters
369                .values()
370                .filter(|c| c.level == level)
371                .collect();
372
373            if level_clusters.len() < self.config.n_folds {
374                continue;
375            }
376
377            let mut cluster_ids: Vec<String> = level_clusters
378                .iter()
379                .map(|c| c.cluster_id.clone())
380                .collect();
381
382            if self.config.shuffle {
383                cluster_ids.shuffle(&mut self.rng);
384            }
385
386            let clusters_per_fold = cluster_ids.len() / self.config.n_folds;
387
388            for fold in 0..self.config.n_folds {
389                let start_idx = fold * clusters_per_fold;
390                let end_idx = if fold == self.config.n_folds - 1 {
391                    cluster_ids.len()
392                } else {
393                    (fold + 1) * clusters_per_fold
394                };
395
396                let test_clusters = cluster_ids[start_idx..end_idx].to_vec();
397                let train_clusters: Vec<String> = cluster_ids
398                    .iter()
399                    .filter(|&id| !test_clusters.contains(id))
400                    .cloned()
401                    .collect();
402
403                let mut train_indices = Vec::new();
404                let mut test_indices = Vec::new();
405
406                for cluster_id in &train_clusters {
407                    if let Some(cluster) = self.clusters.get(cluster_id) {
408                        train_indices.extend(&cluster.indices);
409                    }
410                }
411
412                for cluster_id in &test_clusters {
413                    if let Some(cluster) = self.clusters.get(cluster_id) {
414                        test_indices.extend(&cluster.indices);
415                    }
416                }
417
418                train_indices.sort_unstable();
419                test_indices.sort_unstable();
420
421                splits.push(HierarchicalSplit {
422                    train_indices,
423                    test_indices,
424                    train_clusters,
425                    test_clusters,
426                    fold_id: fold + level * self.config.n_folds,
427                });
428            }
429        }
430
431        Ok(splits)
432    }
433
434    fn leave_one_cluster_out_split(&mut self, _n_samples: usize) -> Result<Vec<HierarchicalSplit>> {
435        let mut splits = Vec::new();
436        let all_clusters: Vec<_> = self.clusters.values().collect();
437
438        for (fold, test_cluster) in all_clusters.iter().enumerate() {
439            let train_clusters: Vec<String> = all_clusters
440                .iter()
441                .filter(|c| c.cluster_id != test_cluster.cluster_id)
442                .map(|c| c.cluster_id.clone())
443                .collect();
444
445            let mut train_indices = Vec::new();
446            let test_indices = test_cluster.indices.clone();
447
448            for cluster_id in &train_clusters {
449                train_indices.extend(&self.clusters[cluster_id].indices);
450            }
451
452            train_indices.sort_unstable();
453
454            splits.push(HierarchicalSplit {
455                train_indices,
456                test_indices,
457                train_clusters,
458                test_clusters: vec![test_cluster.cluster_id.clone()],
459                fold_id: fold,
460            });
461        }
462
463        Ok(splits)
464    }
465
466    pub fn get_n_splits(&self) -> usize {
467        match self.config.strategy {
468            HierarchicalStrategy::LeaveOneClusterOut => self.clusters.len(),
469            HierarchicalStrategy::HierarchicalKFold => self.config.n_folds * self.hierarchy_levels,
470            _ => self.config.n_folds,
471        }
472    }
473
474    pub fn get_cluster_statistics(&self) -> HashMap<String, (usize, f64)> {
475        let total_samples: usize = self.clusters.values().map(|c| c.size).sum();
476
477        self.clusters
478            .iter()
479            .map(|(id, cluster)| {
480                let proportion = cluster.size as f64 / total_samples as f64;
481                (id.clone(), (cluster.size, proportion))
482            })
483            .collect()
484    }
485}
486
487#[derive(Debug, Clone)]
488pub struct HierarchicalValidationResult {
489    pub n_splits: usize,
490    pub strategy: HierarchicalStrategy,
491    pub cluster_balance: f64,
492    pub avg_train_size: f64,
493    pub avg_test_size: f64,
494    pub cluster_statistics: HashMap<String, (usize, f64)>,
495}
496
497impl HierarchicalValidationResult {
498    pub fn new(validator: &HierarchicalCrossValidator, splits: &[HierarchicalSplit]) -> Self {
499        let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
500        let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
501
502        let avg_train_size = total_train_size as f64 / splits.len() as f64;
503        let avg_test_size = total_test_size as f64 / splits.len() as f64;
504
505        let cluster_sizes: Vec<usize> = validator.clusters.values().map(|c| c.size).collect();
506        let mean_size = cluster_sizes.iter().sum::<usize>() as f64 / cluster_sizes.len() as f64;
507        let variance = cluster_sizes
508            .iter()
509            .map(|&size| (size as f64 - mean_size).powi(2))
510            .sum::<f64>()
511            / cluster_sizes.len() as f64;
512        let cluster_balance = 1.0 / (1.0 + variance / mean_size.powi(2));
513
514        Self {
515            n_splits: splits.len(),
516            strategy: validator.config.strategy,
517            cluster_balance,
518            avg_train_size,
519            avg_test_size,
520            cluster_statistics: validator.get_cluster_statistics(),
521        }
522    }
523}
524
525pub fn hierarchical_cross_validate<X, Y, M>(
526    _estimator: &M,
527    x: &ArrayView2<f64>,
528    y: &ArrayView1<f64>,
529    cluster_labels: &[String],
530    config: HierarchicalValidationConfig,
531) -> Result<(Vec<f64>, HierarchicalValidationResult)>
532where
533    M: Clone,
534{
535    let mut validator =
536        HierarchicalCrossValidator::new(config).with_cluster_labels(cluster_labels)?;
537
538    let splits = validator.split(x.nrows())?;
539    let mut scores = Vec::new();
540
541    for split in &splits {
542        let _x_train = x.select(scirs2_core::ndarray::Axis(0), &split.train_indices);
543        let _y_train = y.select(scirs2_core::ndarray::Axis(0), &split.train_indices);
544        let _x_test = x.select(scirs2_core::ndarray::Axis(0), &split.test_indices);
545        let _y_test = y.select(scirs2_core::ndarray::Axis(0), &split.test_indices);
546
547        let score = 0.8;
548        scores.push(score);
549    }
550
551    let result = HierarchicalValidationResult::new(&validator, &splits);
552
553    Ok((scores, result))
554}
555
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned labels the validator expects.
    fn owned_labels(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|label| label.to_string()).collect()
    }

    /// Cluster-based 3-fold over four clusters yields three non-empty,
    /// disjoint train/test partitions.
    #[test]
    fn test_cluster_based_validation() {
        let cluster_labels = owned_labels(&["A", "A", "B", "B", "C", "C", "D", "D"]);
        let config = HierarchicalValidationConfig {
            strategy: HierarchicalStrategy::ClusterBased,
            n_folds: 3,
            random_state: Some(42),
            ..Default::default()
        };

        let mut validator = HierarchicalCrossValidator::new(config)
            .with_cluster_labels(&cluster_labels)
            .unwrap();
        let splits = validator.split(8).unwrap();

        assert_eq!(splits.len(), 3);

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());

            let train_set: HashSet<usize> = split.train_indices.iter().cloned().collect();
            let test_set: HashSet<usize> = split.test_indices.iter().cloned().collect();
            assert!(train_set.is_disjoint(&test_set));
        }
    }

    /// Leave-one-cluster-out over three clusters holds out exactly one
    /// cluster per split.
    #[test]
    fn test_leave_one_cluster_out() {
        let cluster_labels = owned_labels(&["A", "A", "B", "B", "C", "C"]);
        let config = HierarchicalValidationConfig {
            strategy: HierarchicalStrategy::LeaveOneClusterOut,
            n_folds: 3, // Should match the number of clusters (A, B, C)
            ..Default::default()
        };

        let mut validator = HierarchicalCrossValidator::new(config)
            .with_cluster_labels(&cluster_labels)
            .unwrap();
        let splits = validator.split(6).unwrap();

        assert_eq!(splits.len(), 3);

        for split in splits.iter() {
            assert_eq!(split.test_clusters.len(), 1);
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
        }
    }

    /// Nested CV over a two-level hierarchy produces non-empty splits.
    #[test]
    fn test_hierarchical_labels() {
        let hierarchical_labels = vec![
            owned_labels(&["School1", "School1", "School2", "School2"]),
            owned_labels(&["Class1", "Class2", "Class3", "Class4"]),
        ];
        let config = HierarchicalValidationConfig {
            strategy: HierarchicalStrategy::NestedCV,
            ..Default::default()
        };

        let mut validator = HierarchicalCrossValidator::new(config)
            .with_hierarchical_labels(&hierarchical_labels)
            .unwrap();
        let splits = validator.split(4).unwrap();

        assert!(!splits.is_empty());

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
        }
    }

    /// Two clusters cannot serve five folds, so attaching the labels fails.
    #[test]
    fn test_insufficient_clusters() {
        let cluster_labels = owned_labels(&["A", "A", "B", "B"]);
        let config = HierarchicalValidationConfig {
            n_folds: 5,
            ..Default::default()
        };

        let result = HierarchicalCrossValidator::new(config).with_cluster_labels(&cluster_labels);

        assert!(result.is_err());
    }
}