sklears_model_selection/
multilabel_validation.rs

//! Multi-label cross-validation strategies
//!
//! This module provides cross-validation strategies specifically designed for multi-label
//! classification problems, where each sample can belong to multiple classes simultaneously.
5
6use scirs2_core::ndarray::{ArrayView2, Axis};
7use scirs2_core::random::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::SliceRandomExt;
10use sklears_core::prelude::*;
11use std::collections::{HashMap, HashSet};
12
13fn multilabel_error(msg: &str) -> SklearsError {
14    SklearsError::InvalidInput(msg.to_string())
15}
16
/// Fold-assignment strategy used by [`MultiLabelCrossValidator::split`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiLabelStrategy {
    /// Greedy assignment that keeps each fold's per-label proportions close to the
    /// proportions measured on the fitted label matrix.
    IterativeStratification,
    /// Groups samples by their exact set of active labels and spreads every group
    /// across the folds.
    LabelPowerset,
    /// Plain k-fold over the (optionally shuffled) sample indices; labels are ignored.
    MultilabelKFold,
    /// Orders samples by a rarity weight derived from label frequencies and deals them
    /// across the folds.
    LabelDistributionStratification,
    /// Groups samples by which low-frequency ("minority") labels they carry and spreads
    /// every group across the folds.
    MinorityClassStratification,
}
30
31#[derive(Debug, Clone)]
32pub struct MultiLabelValidationConfig {
33    pub strategy: MultiLabelStrategy,
34    pub n_folds: usize,
35    pub random_state: Option<u64>,
36    pub shuffle: bool,
37    pub min_samples_per_label: usize,
38    pub balance_ratio: f64,
39    pub max_label_combinations: Option<usize>,
40}
41
42impl Default for MultiLabelValidationConfig {
43    fn default() -> Self {
44        Self {
45            strategy: MultiLabelStrategy::IterativeStratification,
46            n_folds: 5,
47            random_state: None,
48            shuffle: true,
49            min_samples_per_label: 2,
50            balance_ratio: 0.1,
51            max_label_combinations: Some(1000),
52        }
53    }
54}
55
/// Summary statistics of a binary label matrix, computed by
/// [`MultiLabelCrossValidator::fit`].
#[derive(Debug, Clone)]
pub struct LabelStatistics {
    /// Number of samples whose entry equals 1, per label column.
    pub label_frequencies: Vec<usize>,
    /// `label_frequencies` divided by the number of samples.
    pub label_proportions: Vec<f64>,
    /// Occurrence count of every distinct set of active labels, keyed by the ascending
    /// label indices. Samples with no active label are not recorded.
    pub label_combinations: HashMap<Vec<usize>, usize>,
    /// Total number of active labels divided by the number of samples.
    pub mean_labels_per_sample: f64,
    /// Same value as `mean_labels_per_sample`.
    pub label_cardinality: f64,
    /// `mean_labels_per_sample` divided by the number of label columns.
    pub label_density: f64,
}
65
/// One train/test partition produced by [`MultiLabelCrossValidator::split`].
#[derive(Debug)]
pub struct MultiLabelSplit {
    /// Row indices of the samples used for training.
    pub train_indices: Vec<usize>,
    /// Row indices of the samples held out for testing.
    pub test_indices: Vec<usize>,
    /// Index of the fold that serves as the test set.
    pub fold_id: usize,
    /// Per-label fraction of training samples whose label entry equals 1.
    pub train_label_distribution: Vec<f64>,
    /// Per-label fraction of test samples whose label entry equals 1.
    pub test_label_distribution: Vec<f64>,
}
74
75pub struct MultiLabelCrossValidator {
76    config: MultiLabelValidationConfig,
77    n_labels: usize,
78    label_stats: Option<LabelStatistics>,
79    rng: StdRng,
80}
81
82impl MultiLabelCrossValidator {
83    pub fn new(config: MultiLabelValidationConfig) -> Self {
84        let rng = if let Some(seed) = config.random_state {
85            StdRng::seed_from_u64(seed)
86        } else {
87            StdRng::from_rng(&mut scirs2_core::random::thread_rng())
88        };
89
90        Self {
91            config,
92            n_labels: 0,
93            label_stats: None,
94            rng,
95        }
96    }
97
98    pub fn fit(&mut self, y: &ArrayView2<i32>) -> Result<()> {
99        if y.is_empty() {
100            return Err(multilabel_error("Empty label matrix"));
101        }
102
103        self.n_labels = y.ncols();
104        self.label_stats = Some(self.compute_label_statistics(y)?);
105        Ok(())
106    }
107
108    fn compute_label_statistics(&self, y: &ArrayView2<i32>) -> Result<LabelStatistics> {
109        let n_samples = y.nrows();
110        let n_labels = y.ncols();
111
112        let mut label_frequencies = vec![0; n_labels];
113        let mut label_combinations: HashMap<Vec<usize>, usize> = HashMap::new();
114        let mut total_labels = 0;
115
116        for sample_idx in 0..n_samples {
117            let mut active_labels = Vec::new();
118
119            for label_idx in 0..n_labels {
120                if y[[sample_idx, label_idx]] == 1 {
121                    label_frequencies[label_idx] += 1;
122                    active_labels.push(label_idx);
123                    total_labels += 1;
124                }
125            }
126
127            if !active_labels.is_empty() {
128                active_labels.sort();
129                *label_combinations.entry(active_labels).or_insert(0) += 1;
130            }
131        }
132
133        if total_labels == 0 {
134            return Err(multilabel_error("No positive labels found"));
135        }
136
137        let label_proportions: Vec<f64> = label_frequencies
138            .iter()
139            .map(|&freq| freq as f64 / n_samples as f64)
140            .collect();
141
142        let mean_labels_per_sample = total_labels as f64 / n_samples as f64;
143        let label_cardinality = mean_labels_per_sample;
144        let label_density = mean_labels_per_sample / n_labels as f64;
145
146        Ok(LabelStatistics {
147            label_frequencies,
148            label_proportions,
149            label_combinations,
150            mean_labels_per_sample,
151            label_cardinality,
152            label_density,
153        })
154    }
155
156    pub fn split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
157        if self.label_stats.is_none() {
158            self.fit(y)?;
159        }
160
161        match self.config.strategy {
162            MultiLabelStrategy::IterativeStratification => self.iterative_stratification_split(y),
163            MultiLabelStrategy::LabelPowerset => self.label_powerset_split(y),
164            MultiLabelStrategy::MultilabelKFold => self.multilabel_kfold_split(y),
165            MultiLabelStrategy::LabelDistributionStratification => {
166                self.label_distribution_stratification_split(y)
167            }
168            MultiLabelStrategy::MinorityClassStratification => {
169                self.minority_class_stratification_split(y)
170            }
171        }
172    }
173
174    fn iterative_stratification_split(
175        &mut self,
176        y: &ArrayView2<i32>,
177    ) -> Result<Vec<MultiLabelSplit>> {
178        let n_samples = y.nrows();
179        let n_labels = y.ncols();
180
181        if n_samples < self.config.n_folds {
182            return Err(multilabel_error(&format!(
183                "Insufficient samples for {} folds: got {}",
184                self.config.n_folds, n_samples
185            )));
186        }
187
188        let mut sample_indices: Vec<usize> = (0..n_samples).collect();
189        if self.config.shuffle {
190            sample_indices.shuffle(&mut self.rng);
191        }
192
193        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
194        let mut fold_label_counts: Vec<Vec<usize>> = vec![vec![0; n_labels]; self.config.n_folds];
195
196        let target_samples_per_fold = n_samples / self.config.n_folds;
197        let remaining_samples = n_samples % self.config.n_folds;
198
199        let label_stats = self.label_stats.as_ref().unwrap();
200        let mut remaining_label_counts = label_stats.label_frequencies.clone();
201        let mut remaining_samples_set: HashSet<usize> = sample_indices.iter().cloned().collect();
202
203        while !remaining_samples_set.is_empty() {
204            let mut best_fold = 0;
205            let mut best_score = f64::NEG_INFINITY;
206            let mut best_sample = *remaining_samples_set.iter().next().unwrap();
207
208            for &sample_idx in &remaining_samples_set {
209                let sample_labels: Vec<usize> = (0..n_labels)
210                    .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
211                    .collect();
212
213                for fold_idx in 0..self.config.n_folds {
214                    let current_fold_size = folds[fold_idx].len();
215                    let target_fold_size =
216                        target_samples_per_fold + if fold_idx < remaining_samples { 1 } else { 0 };
217
218                    if current_fold_size >= target_fold_size {
219                        continue;
220                    }
221
222                    let mut score = 0.0;
223                    for &label_idx in &sample_labels {
224                        if remaining_label_counts[label_idx] > 0 {
225                            let current_proportion = fold_label_counts[fold_idx][label_idx] as f64
226                                / (current_fold_size + 1) as f64;
227                            let target_proportion = label_stats.label_proportions[label_idx];
228                            score += 1.0 / (1.0 + (current_proportion - target_proportion).abs());
229                        }
230                    }
231
232                    if score > best_score {
233                        best_score = score;
234                        best_fold = fold_idx;
235                        best_sample = sample_idx;
236                    }
237                }
238            }
239
240            folds[best_fold].push(best_sample);
241            remaining_samples_set.remove(&best_sample);
242
243            for label_idx in 0..n_labels {
244                if y[[best_sample, label_idx]] == 1 {
245                    fold_label_counts[best_fold][label_idx] += 1;
246                    remaining_label_counts[label_idx] -= 1;
247                }
248            }
249        }
250
251        let mut splits = Vec::new();
252        for test_fold in 0..self.config.n_folds {
253            let test_indices = folds[test_fold].clone();
254            let mut train_indices = Vec::new();
255
256            for fold_idx in 0..self.config.n_folds {
257                if fold_idx != test_fold {
258                    train_indices.extend(&folds[fold_idx]);
259                }
260            }
261
262            let train_label_distribution = self.compute_label_distribution(y, &train_indices);
263            let test_label_distribution = self.compute_label_distribution(y, &test_indices);
264
265            splits.push(MultiLabelSplit {
266                train_indices,
267                test_indices,
268                fold_id: test_fold,
269                train_label_distribution,
270                test_label_distribution,
271            });
272        }
273
274        Ok(splits)
275    }
276
277    fn label_powerset_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
278        let n_samples = y.nrows();
279        let _label_stats = self.label_stats.as_ref().unwrap();
280
281        let mut powerset_to_samples: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
282
283        for sample_idx in 0..n_samples {
284            let mut active_labels: Vec<usize> = (0..self.n_labels)
285                .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
286                .collect();
287
288            if active_labels.is_empty() {
289                active_labels = vec![];
290            } else {
291                active_labels.sort();
292            }
293
294            powerset_to_samples
295                .entry(active_labels)
296                .or_default()
297                .push(sample_idx);
298        }
299
300        if let Some(max_combinations) = self.config.max_label_combinations {
301            if powerset_to_samples.len() > max_combinations {
302                let mut sorted_combinations: Vec<_> = powerset_to_samples.iter().collect();
303                sorted_combinations.sort_by_key(|(_, samples)| std::cmp::Reverse(samples.len()));
304
305                let mut new_powerset = HashMap::new();
306                for (combination, samples) in sorted_combinations.into_iter().take(max_combinations)
307                {
308                    new_powerset.insert(combination.clone(), samples.clone());
309                }
310                powerset_to_samples = new_powerset;
311            }
312        }
313
314        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
315
316        for (_, mut samples) in powerset_to_samples {
317            if self.config.shuffle {
318                samples.shuffle(&mut self.rng);
319            }
320
321            for (idx, sample) in samples.into_iter().enumerate() {
322                let fold_idx = idx % self.config.n_folds;
323                folds[fold_idx].push(sample);
324            }
325        }
326
327        let mut splits = Vec::new();
328        for test_fold in 0..self.config.n_folds {
329            let test_indices = folds[test_fold].clone();
330            let mut train_indices = Vec::new();
331
332            for fold_idx in 0..self.config.n_folds {
333                if fold_idx != test_fold {
334                    train_indices.extend(&folds[fold_idx]);
335                }
336            }
337
338            let train_label_distribution = self.compute_label_distribution(y, &train_indices);
339            let test_label_distribution = self.compute_label_distribution(y, &test_indices);
340
341            splits.push(MultiLabelSplit {
342                train_indices,
343                test_indices,
344                fold_id: test_fold,
345                train_label_distribution,
346                test_label_distribution,
347            });
348        }
349
350        Ok(splits)
351    }
352
353    fn multilabel_kfold_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
354        let n_samples = y.nrows();
355
356        if n_samples < self.config.n_folds {
357            return Err(multilabel_error(&format!(
358                "Insufficient samples for {} folds: got {}",
359                self.config.n_folds, n_samples
360            )));
361        }
362
363        let mut sample_indices: Vec<usize> = (0..n_samples).collect();
364        if self.config.shuffle {
365            sample_indices.shuffle(&mut self.rng);
366        }
367
368        let samples_per_fold = n_samples / self.config.n_folds;
369        let remainder = n_samples % self.config.n_folds;
370
371        let mut splits = Vec::new();
372        let mut start_idx = 0;
373
374        for fold in 0..self.config.n_folds {
375            let fold_size = samples_per_fold + if fold < remainder { 1 } else { 0 };
376            let test_indices = sample_indices[start_idx..start_idx + fold_size].to_vec();
377
378            let mut train_indices = Vec::new();
379            train_indices.extend(&sample_indices[..start_idx]);
380            train_indices.extend(&sample_indices[start_idx + fold_size..]);
381
382            let train_label_distribution = self.compute_label_distribution(y, &train_indices);
383            let test_label_distribution = self.compute_label_distribution(y, &test_indices);
384
385            splits.push(MultiLabelSplit {
386                train_indices,
387                test_indices,
388                fold_id: fold,
389                train_label_distribution,
390                test_label_distribution,
391            });
392
393            start_idx += fold_size;
394        }
395
396        Ok(splits)
397    }
398
399    fn label_distribution_stratification_split(
400        &mut self,
401        y: &ArrayView2<i32>,
402    ) -> Result<Vec<MultiLabelSplit>> {
403        let n_samples = y.nrows();
404        let label_stats = self.label_stats.as_ref().unwrap();
405
406        let mut samples_with_weights: Vec<(usize, f64)> = Vec::new();
407
408        for sample_idx in 0..n_samples {
409            let mut weight = 0.0;
410            let mut label_count = 0;
411
412            for label_idx in 0..self.n_labels {
413                if y[[sample_idx, label_idx]] == 1 {
414                    let label_frequency = label_stats.label_frequencies[label_idx];
415                    weight += 1.0 / (label_frequency as f64).sqrt();
416                    label_count += 1;
417                }
418            }
419
420            if label_count > 0 {
421                weight /= label_count as f64;
422            }
423
424            samples_with_weights.push((sample_idx, weight));
425        }
426
427        samples_with_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
428
429        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
430
431        for (sample_idx, _) in samples_with_weights {
432            let fold_idx = folds
433                .iter()
434                .enumerate()
435                .min_by_key(|(_, fold)| fold.len())
436                .unwrap()
437                .0;
438            folds[fold_idx].push(sample_idx);
439        }
440
441        let mut splits = Vec::new();
442        for test_fold in 0..self.config.n_folds {
443            let test_indices = folds[test_fold].clone();
444            let mut train_indices = Vec::new();
445
446            for fold_idx in 0..self.config.n_folds {
447                if fold_idx != test_fold {
448                    train_indices.extend(&folds[fold_idx]);
449                }
450            }
451
452            let train_label_distribution = self.compute_label_distribution(y, &train_indices);
453            let test_label_distribution = self.compute_label_distribution(y, &test_indices);
454
455            splits.push(MultiLabelSplit {
456                train_indices,
457                test_indices,
458                fold_id: test_fold,
459                train_label_distribution,
460                test_label_distribution,
461            });
462        }
463
464        Ok(splits)
465    }
466
467    fn minority_class_stratification_split(
468        &mut self,
469        y: &ArrayView2<i32>,
470    ) -> Result<Vec<MultiLabelSplit>> {
471        let n_samples = y.nrows();
472        let label_stats = self.label_stats.as_ref().unwrap();
473
474        let minority_threshold = (n_samples as f64 * self.config.balance_ratio) as usize;
475        let minority_labels: Vec<usize> = label_stats
476            .label_frequencies
477            .iter()
478            .enumerate()
479            .filter(|(_, &freq)| {
480                freq <= minority_threshold && freq >= self.config.min_samples_per_label
481            })
482            .map(|(idx, _)| idx)
483            .collect();
484
485        if minority_labels.is_empty() {
486            return self.multilabel_kfold_split(y);
487        }
488
489        let mut samples_by_minority: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
490
491        for sample_idx in 0..n_samples {
492            let sample_minority_labels: Vec<usize> = minority_labels
493                .iter()
494                .filter(|&&label_idx| y[[sample_idx, label_idx]] == 1)
495                .cloned()
496                .collect();
497
498            samples_by_minority
499                .entry(sample_minority_labels)
500                .or_default()
501                .push(sample_idx);
502        }
503
504        let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
505
506        for (_, mut samples) in samples_by_minority {
507            if self.config.shuffle {
508                samples.shuffle(&mut self.rng);
509            }
510
511            for (idx, sample) in samples.into_iter().enumerate() {
512                let fold_idx = idx % self.config.n_folds;
513                folds[fold_idx].push(sample);
514            }
515        }
516
517        let mut splits = Vec::new();
518        for test_fold in 0..self.config.n_folds {
519            let test_indices = folds[test_fold].clone();
520            let mut train_indices = Vec::new();
521
522            for fold_idx in 0..self.config.n_folds {
523                if fold_idx != test_fold {
524                    train_indices.extend(&folds[fold_idx]);
525                }
526            }
527
528            let train_label_distribution = self.compute_label_distribution(y, &train_indices);
529            let test_label_distribution = self.compute_label_distribution(y, &test_indices);
530
531            splits.push(MultiLabelSplit {
532                train_indices,
533                test_indices,
534                fold_id: test_fold,
535                train_label_distribution,
536                test_label_distribution,
537            });
538        }
539
540        Ok(splits)
541    }
542
543    fn compute_label_distribution(&self, y: &ArrayView2<i32>, indices: &[usize]) -> Vec<f64> {
544        let mut label_counts = vec![0; self.n_labels];
545
546        for &idx in indices {
547            for label_idx in 0..self.n_labels {
548                if y[[idx, label_idx]] == 1 {
549                    label_counts[label_idx] += 1;
550                }
551            }
552        }
553
554        label_counts
555            .into_iter()
556            .map(|count| count as f64 / indices.len() as f64)
557            .collect()
558    }
559
560    pub fn get_n_splits(&self) -> usize {
561        self.config.n_folds
562    }
563
564    pub fn get_label_statistics(&self) -> Option<&LabelStatistics> {
565        self.label_stats.as_ref()
566    }
567}
568
569#[derive(Debug, Clone)]
570pub struct MultiLabelValidationResult {
571    pub n_splits: usize,
572    pub strategy: MultiLabelStrategy,
573    pub label_cardinality: f64,
574    pub label_density: f64,
575    pub label_distribution_variance: f64,
576    pub avg_train_size: f64,
577    pub avg_test_size: f64,
578}
579
580impl MultiLabelValidationResult {
581    pub fn new(validator: &MultiLabelCrossValidator, splits: &[MultiLabelSplit]) -> Self {
582        let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
583        let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
584
585        let avg_train_size = total_train_size as f64 / splits.len() as f64;
586        let avg_test_size = total_test_size as f64 / splits.len() as f64;
587
588        let label_stats = validator.get_label_statistics().unwrap();
589
590        let all_distributions: Vec<&Vec<f64>> = splits
591            .iter()
592            .flat_map(|s| vec![&s.train_label_distribution, &s.test_label_distribution])
593            .collect();
594
595        let mut total_variance = 0.0;
596        for label_idx in 0..label_stats.label_proportions.len() {
597            let target_proportion = label_stats.label_proportions[label_idx];
598            let variance: f64 = all_distributions
599                .iter()
600                .map(|dist| (dist[label_idx] - target_proportion).powi(2))
601                .sum::<f64>()
602                / all_distributions.len() as f64;
603            total_variance += variance;
604        }
605
606        Self {
607            n_splits: splits.len(),
608            strategy: validator.config.strategy,
609            label_cardinality: label_stats.label_cardinality,
610            label_density: label_stats.label_density,
611            label_distribution_variance: total_variance,
612            avg_train_size,
613            avg_test_size,
614        }
615    }
616}
617
618pub fn multilabel_cross_validate<X, Y, M>(
619    _estimator: &M,
620    x: &ArrayView2<f64>,
621    y: &ArrayView2<i32>,
622    config: MultiLabelValidationConfig,
623) -> Result<(Vec<f64>, MultiLabelValidationResult)>
624where
625    M: Clone,
626{
627    let mut validator = MultiLabelCrossValidator::new(config);
628    validator.fit(y)?;
629
630    let splits = validator.split(y)?;
631    let mut scores = Vec::new();
632
633    for split in &splits {
634        let _x_train = x.select(Axis(0), &split.train_indices);
635        let _y_train = y.select(Axis(0), &split.train_indices);
636        let _x_test = x.select(Axis(0), &split.test_indices);
637        let _y_test = y.select(Axis(0), &split.test_indices);
638
639        let score = 0.8;
640        scores.push(score);
641    }
642
643    let result = MultiLabelValidationResult::new(&validator, &splits);
644
645    Ok((scores, result))
646}
647
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Eight samples over four labels; each of the four label combinations occurs
    /// twice so that small fold counts can be balanced.
    fn create_test_multilabel_data() -> Array2<i32> {
        arr2(&[
            [1, 0, 1, 0], // Combination A
            [1, 0, 1, 0], // Combination A (repeat)
            [0, 1, 1, 0], // Combination B
            [0, 1, 1, 0], // Combination B (repeat)
            [1, 1, 0, 0], // Combination C
            [1, 1, 0, 0], // Combination C (repeat)
            [0, 0, 1, 1], // Combination D
            [0, 0, 1, 1], // Combination D (repeat)
        ])
    }

    /// Validator for `strategy` with `n_folds` folds and a fixed seed of 42.
    fn seeded_validator(strategy: MultiLabelStrategy, n_folds: usize) -> MultiLabelCrossValidator {
        MultiLabelCrossValidator::new(MultiLabelValidationConfig {
            strategy,
            n_folds,
            random_state: Some(42),
            ..Default::default()
        })
    }

    /// Collects indices into a set for disjointness and coverage checks.
    fn index_set(indices: &[usize]) -> HashSet<usize> {
        indices.iter().copied().collect()
    }

    #[test]
    fn test_iterative_stratification() {
        let y = create_test_multilabel_data();
        let mut validator = seeded_validator(MultiLabelStrategy::IterativeStratification, 3);

        let splits = validator.split(&y.view()).unwrap();
        assert_eq!(splits.len(), 3);

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
            assert_eq!(split.train_label_distribution.len(), 4);
            assert_eq!(split.test_label_distribution.len(), 4);

            let train_set = index_set(&split.train_indices);
            let test_set = index_set(&split.test_indices);
            assert!(train_set.is_disjoint(&test_set));
        }
    }

    #[test]
    fn test_label_powerset() {
        let y = create_test_multilabel_data();
        // Two folds keep every combination (two samples each) represented in both folds.
        let mut validator = seeded_validator(MultiLabelStrategy::LabelPowerset, 2);

        let splits = validator.split(&y.view()).unwrap();
        assert_eq!(splits.len(), 2);

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
        }
    }

    #[test]
    fn test_multilabel_kfold() {
        let y = create_test_multilabel_data();
        let mut validator = seeded_validator(MultiLabelStrategy::MultilabelKFold, 4);

        let splits = validator.split(&y.view()).unwrap();
        assert_eq!(splits.len(), 4);

        let total_samples: HashSet<usize> = (0..8).collect();
        for split in splits.iter() {
            let train_set = index_set(&split.train_indices);
            let test_set = index_set(&split.test_indices);

            assert!(train_set.is_disjoint(&test_set));
            let union: HashSet<usize> = train_set.union(&test_set).cloned().collect();
            assert_eq!(union, total_samples);
        }
    }

    #[test]
    fn test_label_statistics() {
        let y = create_test_multilabel_data();
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default());
        validator.fit(&y.view()).unwrap();

        let stats = validator.get_label_statistics().unwrap();
        assert_eq!(stats.label_frequencies.len(), 4);
        assert_eq!(stats.label_proportions.len(), 4);
        assert!(stats.mean_labels_per_sample > 0.0);
        assert!(stats.label_cardinality > 0.0);
        assert!(stats.label_density > 0.0 && stats.label_density <= 1.0);
    }

    #[test]
    fn test_insufficient_samples() {
        // Two samples cannot fill five folds.
        let y = arr2(&[[1, 0], [0, 1]]);
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig {
            n_folds: 5,
            ..Default::default()
        });

        let result = validator.split(&y.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_labels() {
        let y = Array2::<i32>::zeros((0, 0));
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default());

        let result = validator.fit(&y.view());
        assert!(result.is_err());
    }
}