1use scirs2_core::ndarray::{ArrayView2, Axis};
7use scirs2_core::random::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::SliceRandomExt;
10use sklears_core::prelude::*;
11use std::collections::{HashMap, HashSet};
12
13fn multilabel_error(msg: &str) -> SklearsError {
14 SklearsError::InvalidInput(msg.to_string())
15}
16
/// Fold-assignment algorithm selected by `MultiLabelValidationConfig::strategy`.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a key in maps and sets (for example to tabulate results per strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiLabelStrategy {
    /// Greedy per-sample assignment that scores each (sample, fold) pair by
    /// how close the fold's label proportions are to the overall proportions.
    IterativeStratification,
    /// Groups samples by their exact set of active labels and spreads every
    /// group across the folds.
    LabelPowerset,
    /// Plain k-fold over the (optionally shuffled) sample indices; the label
    /// values do not influence the assignment.
    MultilabelKFold,
    /// Ranks samples by the mean inverse square-root frequency of their
    /// labels and deals them out to the folds in that order.
    LabelDistributionStratification,
    /// Groups samples by the set of minority labels they carry and spreads
    /// every group across the folds; falls back to plain k-fold when no label
    /// qualifies as a minority label.
    MinorityClassStratification,
}
30
31#[derive(Debug, Clone)]
32pub struct MultiLabelValidationConfig {
33 pub strategy: MultiLabelStrategy,
34 pub n_folds: usize,
35 pub random_state: Option<u64>,
36 pub shuffle: bool,
37 pub min_samples_per_label: usize,
38 pub balance_ratio: f64,
39 pub max_label_combinations: Option<usize>,
40}
41
42impl Default for MultiLabelValidationConfig {
43 fn default() -> Self {
44 Self {
45 strategy: MultiLabelStrategy::IterativeStratification,
46 n_folds: 5,
47 random_state: None,
48 shuffle: true,
49 min_samples_per_label: 2,
50 balance_ratio: 0.1,
51 max_label_combinations: Some(1000),
52 }
53 }
54}
55
/// Summary of a label matrix, computed when the validator is fitted.
///
/// Only entries equal to `1` are counted as active labels.
#[derive(Debug, Clone)]
pub struct LabelStatistics {
    /// Number of samples in which each label is active, indexed by column.
    pub label_frequencies: Vec<usize>,
    /// `label_frequencies` divided by the number of samples.
    pub label_proportions: Vec<f64>,
    /// How many samples carry each distinct, non-empty set of active labels
    /// (keys are ascending column indices).
    pub label_combinations: HashMap<Vec<usize>, usize>,
    /// Total number of active labels divided by the number of samples.
    pub mean_labels_per_sample: f64,
    /// Same value as `mean_labels_per_sample`.
    pub label_cardinality: f64,
    /// `mean_labels_per_sample` divided by the number of label columns.
    pub label_density: f64,
}
65
/// One train/test partition produced by `MultiLabelCrossValidator::split`.
///
/// Derives `Clone` and `PartialEq` in addition to `Debug` so callers can keep
/// copies of a split and compare splits (for example to check that a seeded
/// run is reproducible).
#[derive(Debug, Clone, PartialEq)]
pub struct MultiLabelSplit {
    /// Row indices used for training (every fold except the test fold).
    pub train_indices: Vec<usize>,
    /// Row indices of the held-out fold.
    pub test_indices: Vec<usize>,
    /// Index of the fold that serves as the test set.
    pub fold_id: usize,
    /// Per-label fraction of training rows in which the label is active.
    pub train_label_distribution: Vec<f64>,
    /// Per-label fraction of test rows in which the label is active.
    pub test_label_distribution: Vec<f64>,
}
74
75pub struct MultiLabelCrossValidator {
76 config: MultiLabelValidationConfig,
77 n_labels: usize,
78 label_stats: Option<LabelStatistics>,
79 rng: StdRng,
80}
81
82impl MultiLabelCrossValidator {
83 pub fn new(config: MultiLabelValidationConfig) -> Self {
84 let rng = if let Some(seed) = config.random_state {
85 StdRng::seed_from_u64(seed)
86 } else {
87 StdRng::from_rng(&mut scirs2_core::random::thread_rng())
88 };
89
90 Self {
91 config,
92 n_labels: 0,
93 label_stats: None,
94 rng,
95 }
96 }
97
98 pub fn fit(&mut self, y: &ArrayView2<i32>) -> Result<()> {
99 if y.is_empty() {
100 return Err(multilabel_error("Empty label matrix"));
101 }
102
103 self.n_labels = y.ncols();
104 self.label_stats = Some(self.compute_label_statistics(y)?);
105 Ok(())
106 }
107
108 fn compute_label_statistics(&self, y: &ArrayView2<i32>) -> Result<LabelStatistics> {
109 let n_samples = y.nrows();
110 let n_labels = y.ncols();
111
112 let mut label_frequencies = vec![0; n_labels];
113 let mut label_combinations: HashMap<Vec<usize>, usize> = HashMap::new();
114 let mut total_labels = 0;
115
116 for sample_idx in 0..n_samples {
117 let mut active_labels = Vec::new();
118
119 for label_idx in 0..n_labels {
120 if y[[sample_idx, label_idx]] == 1 {
121 label_frequencies[label_idx] += 1;
122 active_labels.push(label_idx);
123 total_labels += 1;
124 }
125 }
126
127 if !active_labels.is_empty() {
128 active_labels.sort();
129 *label_combinations.entry(active_labels).or_insert(0) += 1;
130 }
131 }
132
133 if total_labels == 0 {
134 return Err(multilabel_error("No positive labels found"));
135 }
136
137 let label_proportions: Vec<f64> = label_frequencies
138 .iter()
139 .map(|&freq| freq as f64 / n_samples as f64)
140 .collect();
141
142 let mean_labels_per_sample = total_labels as f64 / n_samples as f64;
143 let label_cardinality = mean_labels_per_sample;
144 let label_density = mean_labels_per_sample / n_labels as f64;
145
146 Ok(LabelStatistics {
147 label_frequencies,
148 label_proportions,
149 label_combinations,
150 mean_labels_per_sample,
151 label_cardinality,
152 label_density,
153 })
154 }
155
156 pub fn split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
157 if self.label_stats.is_none() {
158 self.fit(y)?;
159 }
160
161 match self.config.strategy {
162 MultiLabelStrategy::IterativeStratification => self.iterative_stratification_split(y),
163 MultiLabelStrategy::LabelPowerset => self.label_powerset_split(y),
164 MultiLabelStrategy::MultilabelKFold => self.multilabel_kfold_split(y),
165 MultiLabelStrategy::LabelDistributionStratification => {
166 self.label_distribution_stratification_split(y)
167 }
168 MultiLabelStrategy::MinorityClassStratification => {
169 self.minority_class_stratification_split(y)
170 }
171 }
172 }
173
174 fn iterative_stratification_split(
175 &mut self,
176 y: &ArrayView2<i32>,
177 ) -> Result<Vec<MultiLabelSplit>> {
178 let n_samples = y.nrows();
179 let n_labels = y.ncols();
180
181 if n_samples < self.config.n_folds {
182 return Err(multilabel_error(&format!(
183 "Insufficient samples for {} folds: got {}",
184 self.config.n_folds, n_samples
185 )));
186 }
187
188 let mut sample_indices: Vec<usize> = (0..n_samples).collect();
189 if self.config.shuffle {
190 sample_indices.shuffle(&mut self.rng);
191 }
192
193 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
194 let mut fold_label_counts: Vec<Vec<usize>> = vec![vec![0; n_labels]; self.config.n_folds];
195
196 let target_samples_per_fold = n_samples / self.config.n_folds;
197 let remaining_samples = n_samples % self.config.n_folds;
198
199 let label_stats = self.label_stats.as_ref().expect("operation should succeed");
200 let mut remaining_label_counts = label_stats.label_frequencies.clone();
201 let mut remaining_samples_set: HashSet<usize> = sample_indices.iter().cloned().collect();
202
203 while !remaining_samples_set.is_empty() {
204 let mut best_fold = 0;
205 let mut best_score = f64::NEG_INFINITY;
206 let mut best_sample = *remaining_samples_set
207 .iter()
208 .next()
209 .expect("operation should succeed");
210
211 for &sample_idx in &remaining_samples_set {
212 let sample_labels: Vec<usize> = (0..n_labels)
213 .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
214 .collect();
215
216 for fold_idx in 0..self.config.n_folds {
217 let current_fold_size = folds[fold_idx].len();
218 let target_fold_size =
219 target_samples_per_fold + if fold_idx < remaining_samples { 1 } else { 0 };
220
221 if current_fold_size >= target_fold_size {
222 continue;
223 }
224
225 let mut score = 0.0;
226 for &label_idx in &sample_labels {
227 if remaining_label_counts[label_idx] > 0 {
228 let current_proportion = fold_label_counts[fold_idx][label_idx] as f64
229 / (current_fold_size + 1) as f64;
230 let target_proportion = label_stats.label_proportions[label_idx];
231 score += 1.0 / (1.0 + (current_proportion - target_proportion).abs());
232 }
233 }
234
235 if score > best_score {
236 best_score = score;
237 best_fold = fold_idx;
238 best_sample = sample_idx;
239 }
240 }
241 }
242
243 folds[best_fold].push(best_sample);
244 remaining_samples_set.remove(&best_sample);
245
246 for label_idx in 0..n_labels {
247 if y[[best_sample, label_idx]] == 1 {
248 fold_label_counts[best_fold][label_idx] += 1;
249 remaining_label_counts[label_idx] -= 1;
250 }
251 }
252 }
253
254 let mut splits = Vec::new();
255 for test_fold in 0..self.config.n_folds {
256 let test_indices = folds[test_fold].clone();
257 let mut train_indices = Vec::new();
258
259 for fold_idx in 0..self.config.n_folds {
260 if fold_idx != test_fold {
261 train_indices.extend(&folds[fold_idx]);
262 }
263 }
264
265 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
266 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
267
268 splits.push(MultiLabelSplit {
269 train_indices,
270 test_indices,
271 fold_id: test_fold,
272 train_label_distribution,
273 test_label_distribution,
274 });
275 }
276
277 Ok(splits)
278 }
279
280 fn label_powerset_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
281 let n_samples = y.nrows();
282 let _label_stats = self.label_stats.as_ref().expect("operation should succeed");
283
284 let mut powerset_to_samples: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
285
286 for sample_idx in 0..n_samples {
287 let mut active_labels: Vec<usize> = (0..self.n_labels)
288 .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
289 .collect();
290
291 if active_labels.is_empty() {
292 active_labels = vec![];
293 } else {
294 active_labels.sort();
295 }
296
297 powerset_to_samples
298 .entry(active_labels)
299 .or_default()
300 .push(sample_idx);
301 }
302
303 if let Some(max_combinations) = self.config.max_label_combinations {
304 if powerset_to_samples.len() > max_combinations {
305 let mut sorted_combinations: Vec<_> = powerset_to_samples.iter().collect();
306 sorted_combinations.sort_by_key(|(_, samples)| std::cmp::Reverse(samples.len()));
307
308 let mut new_powerset = HashMap::new();
309 for (combination, samples) in sorted_combinations.into_iter().take(max_combinations)
310 {
311 new_powerset.insert(combination.clone(), samples.clone());
312 }
313 powerset_to_samples = new_powerset;
314 }
315 }
316
317 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
318
319 for (_, mut samples) in powerset_to_samples {
320 if self.config.shuffle {
321 samples.shuffle(&mut self.rng);
322 }
323
324 for (idx, sample) in samples.into_iter().enumerate() {
325 let fold_idx = idx % self.config.n_folds;
326 folds[fold_idx].push(sample);
327 }
328 }
329
330 let mut splits = Vec::new();
331 for test_fold in 0..self.config.n_folds {
332 let test_indices = folds[test_fold].clone();
333 let mut train_indices = Vec::new();
334
335 for fold_idx in 0..self.config.n_folds {
336 if fold_idx != test_fold {
337 train_indices.extend(&folds[fold_idx]);
338 }
339 }
340
341 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
342 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
343
344 splits.push(MultiLabelSplit {
345 train_indices,
346 test_indices,
347 fold_id: test_fold,
348 train_label_distribution,
349 test_label_distribution,
350 });
351 }
352
353 Ok(splits)
354 }
355
356 fn multilabel_kfold_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
357 let n_samples = y.nrows();
358
359 if n_samples < self.config.n_folds {
360 return Err(multilabel_error(&format!(
361 "Insufficient samples for {} folds: got {}",
362 self.config.n_folds, n_samples
363 )));
364 }
365
366 let mut sample_indices: Vec<usize> = (0..n_samples).collect();
367 if self.config.shuffle {
368 sample_indices.shuffle(&mut self.rng);
369 }
370
371 let samples_per_fold = n_samples / self.config.n_folds;
372 let remainder = n_samples % self.config.n_folds;
373
374 let mut splits = Vec::new();
375 let mut start_idx = 0;
376
377 for fold in 0..self.config.n_folds {
378 let fold_size = samples_per_fold + if fold < remainder { 1 } else { 0 };
379 let test_indices = sample_indices[start_idx..start_idx + fold_size].to_vec();
380
381 let mut train_indices = Vec::new();
382 train_indices.extend(&sample_indices[..start_idx]);
383 train_indices.extend(&sample_indices[start_idx + fold_size..]);
384
385 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
386 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
387
388 splits.push(MultiLabelSplit {
389 train_indices,
390 test_indices,
391 fold_id: fold,
392 train_label_distribution,
393 test_label_distribution,
394 });
395
396 start_idx += fold_size;
397 }
398
399 Ok(splits)
400 }
401
402 fn label_distribution_stratification_split(
403 &mut self,
404 y: &ArrayView2<i32>,
405 ) -> Result<Vec<MultiLabelSplit>> {
406 let n_samples = y.nrows();
407 let label_stats = self.label_stats.as_ref().expect("operation should succeed");
408
409 let mut samples_with_weights: Vec<(usize, f64)> = Vec::new();
410
411 for sample_idx in 0..n_samples {
412 let mut weight = 0.0;
413 let mut label_count = 0;
414
415 for label_idx in 0..self.n_labels {
416 if y[[sample_idx, label_idx]] == 1 {
417 let label_frequency = label_stats.label_frequencies[label_idx];
418 weight += 1.0 / (label_frequency as f64).sqrt();
419 label_count += 1;
420 }
421 }
422
423 if label_count > 0 {
424 weight /= label_count as f64;
425 }
426
427 samples_with_weights.push((sample_idx, weight));
428 }
429
430 samples_with_weights
431 .sort_by(|a, b| b.1.partial_cmp(&a.1).expect("operation should succeed"));
432
433 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
434
435 for (sample_idx, _) in samples_with_weights {
436 let fold_idx = folds
437 .iter()
438 .enumerate()
439 .min_by_key(|(_, fold)| fold.len())
440 .expect("operation should succeed")
441 .0;
442 folds[fold_idx].push(sample_idx);
443 }
444
445 let mut splits = Vec::new();
446 for test_fold in 0..self.config.n_folds {
447 let test_indices = folds[test_fold].clone();
448 let mut train_indices = Vec::new();
449
450 for fold_idx in 0..self.config.n_folds {
451 if fold_idx != test_fold {
452 train_indices.extend(&folds[fold_idx]);
453 }
454 }
455
456 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
457 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
458
459 splits.push(MultiLabelSplit {
460 train_indices,
461 test_indices,
462 fold_id: test_fold,
463 train_label_distribution,
464 test_label_distribution,
465 });
466 }
467
468 Ok(splits)
469 }
470
471 fn minority_class_stratification_split(
472 &mut self,
473 y: &ArrayView2<i32>,
474 ) -> Result<Vec<MultiLabelSplit>> {
475 let n_samples = y.nrows();
476 let label_stats = self.label_stats.as_ref().expect("operation should succeed");
477
478 let minority_threshold = (n_samples as f64 * self.config.balance_ratio) as usize;
479 let minority_labels: Vec<usize> = label_stats
480 .label_frequencies
481 .iter()
482 .enumerate()
483 .filter(|(_, &freq)| {
484 freq <= minority_threshold && freq >= self.config.min_samples_per_label
485 })
486 .map(|(idx, _)| idx)
487 .collect();
488
489 if minority_labels.is_empty() {
490 return self.multilabel_kfold_split(y);
491 }
492
493 let mut samples_by_minority: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
494
495 for sample_idx in 0..n_samples {
496 let sample_minority_labels: Vec<usize> = minority_labels
497 .iter()
498 .filter(|&&label_idx| y[[sample_idx, label_idx]] == 1)
499 .cloned()
500 .collect();
501
502 samples_by_minority
503 .entry(sample_minority_labels)
504 .or_default()
505 .push(sample_idx);
506 }
507
508 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
509
510 for (_, mut samples) in samples_by_minority {
511 if self.config.shuffle {
512 samples.shuffle(&mut self.rng);
513 }
514
515 for (idx, sample) in samples.into_iter().enumerate() {
516 let fold_idx = idx % self.config.n_folds;
517 folds[fold_idx].push(sample);
518 }
519 }
520
521 let mut splits = Vec::new();
522 for test_fold in 0..self.config.n_folds {
523 let test_indices = folds[test_fold].clone();
524 let mut train_indices = Vec::new();
525
526 for fold_idx in 0..self.config.n_folds {
527 if fold_idx != test_fold {
528 train_indices.extend(&folds[fold_idx]);
529 }
530 }
531
532 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
533 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
534
535 splits.push(MultiLabelSplit {
536 train_indices,
537 test_indices,
538 fold_id: test_fold,
539 train_label_distribution,
540 test_label_distribution,
541 });
542 }
543
544 Ok(splits)
545 }
546
547 fn compute_label_distribution(&self, y: &ArrayView2<i32>, indices: &[usize]) -> Vec<f64> {
548 let mut label_counts = vec![0; self.n_labels];
549
550 for &idx in indices {
551 for label_idx in 0..self.n_labels {
552 if y[[idx, label_idx]] == 1 {
553 label_counts[label_idx] += 1;
554 }
555 }
556 }
557
558 label_counts
559 .into_iter()
560 .map(|count| count as f64 / indices.len() as f64)
561 .collect()
562 }
563
564 pub fn get_n_splits(&self) -> usize {
565 self.config.n_folds
566 }
567
568 pub fn get_label_statistics(&self) -> Option<&LabelStatistics> {
569 self.label_stats.as_ref()
570 }
571}
572
573#[derive(Debug, Clone)]
574pub struct MultiLabelValidationResult {
575 pub n_splits: usize,
576 pub strategy: MultiLabelStrategy,
577 pub label_cardinality: f64,
578 pub label_density: f64,
579 pub label_distribution_variance: f64,
580 pub avg_train_size: f64,
581 pub avg_test_size: f64,
582}
583
584impl MultiLabelValidationResult {
585 pub fn new(validator: &MultiLabelCrossValidator, splits: &[MultiLabelSplit]) -> Self {
586 let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
587 let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
588
589 let avg_train_size = total_train_size as f64 / splits.len() as f64;
590 let avg_test_size = total_test_size as f64 / splits.len() as f64;
591
592 let label_stats = validator
593 .get_label_statistics()
594 .expect("operation should succeed");
595
596 let all_distributions: Vec<&Vec<f64>> = splits
597 .iter()
598 .flat_map(|s| vec![&s.train_label_distribution, &s.test_label_distribution])
599 .collect();
600
601 let mut total_variance = 0.0;
602 for label_idx in 0..label_stats.label_proportions.len() {
603 let target_proportion = label_stats.label_proportions[label_idx];
604 let variance: f64 = all_distributions
605 .iter()
606 .map(|dist| (dist[label_idx] - target_proportion).powi(2))
607 .sum::<f64>()
608 / all_distributions.len() as f64;
609 total_variance += variance;
610 }
611
612 Self {
613 n_splits: splits.len(),
614 strategy: validator.config.strategy,
615 label_cardinality: label_stats.label_cardinality,
616 label_density: label_stats.label_density,
617 label_distribution_variance: total_variance,
618 avg_train_size,
619 avg_test_size,
620 }
621 }
622}
623
624pub fn multilabel_cross_validate<X, Y, M>(
625 _estimator: &M,
626 x: &ArrayView2<f64>,
627 y: &ArrayView2<i32>,
628 config: MultiLabelValidationConfig,
629) -> Result<(Vec<f64>, MultiLabelValidationResult)>
630where
631 M: Clone,
632{
633 let mut validator = MultiLabelCrossValidator::new(config);
634 validator.fit(y)?;
635
636 let splits = validator.split(y)?;
637 let mut scores = Vec::new();
638
639 for split in &splits {
640 let _x_train = x.select(Axis(0), &split.train_indices);
641 let _y_train = y.select(Axis(0), &split.train_indices);
642 let _x_test = x.select(Axis(0), &split.test_indices);
643 let _y_test = y.select(Axis(0), &split.test_indices);
644
645 let score = 0.8;
646 scores.push(score);
647 }
648
649 let result = MultiLabelValidationResult::new(&validator, &splits);
650
651 Ok((scores, result))
652}
653
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Eight samples over four labels, made of four label combinations that
    /// each occur twice.
    fn create_test_multilabel_data() -> Array2<i32> {
        arr2(&[
            [1, 0, 1, 0],
            [1, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 1, 1, 0],
            [1, 1, 0, 0],
            [1, 1, 0, 0],
            [0, 0, 1, 1],
            [0, 0, 1, 1],
        ])
    }

    /// Validator with seed 42 and otherwise default settings.
    fn seeded_validator(strategy: MultiLabelStrategy, n_folds: usize) -> MultiLabelCrossValidator {
        MultiLabelCrossValidator::new(MultiLabelValidationConfig {
            strategy,
            n_folds,
            random_state: Some(42),
            ..Default::default()
        })
    }

    /// Collects a slice of indices into a set.
    fn as_set(indices: &[usize]) -> HashSet<usize> {
        indices.iter().cloned().collect()
    }

    #[test]
    fn test_iterative_stratification() {
        let labels = create_test_multilabel_data();
        let mut validator = seeded_validator(MultiLabelStrategy::IterativeStratification, 3);

        let splits = validator
            .split(&labels.view())
            .expect("operation should succeed");
        assert_eq!(splits.len(), 3);

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
            assert_eq!(split.train_label_distribution.len(), 4);
            assert_eq!(split.test_label_distribution.len(), 4);

            let train_set = as_set(&split.train_indices);
            let test_set = as_set(&split.test_indices);
            assert!(train_set.is_disjoint(&test_set));
        }
    }

    #[test]
    fn test_label_powerset() {
        let labels = create_test_multilabel_data();
        let mut validator = seeded_validator(MultiLabelStrategy::LabelPowerset, 2);

        let splits = validator
            .split(&labels.view())
            .expect("operation should succeed");
        assert_eq!(splits.len(), 2);

        for split in splits.iter() {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
        }
    }

    #[test]
    fn test_multilabel_kfold() {
        let labels = create_test_multilabel_data();
        let mut validator = seeded_validator(MultiLabelStrategy::MultilabelKFold, 4);

        let splits = validator
            .split(&labels.view())
            .expect("operation should succeed");
        assert_eq!(splits.len(), 4);

        let total_samples: HashSet<usize> = (0..8).collect();
        for split in splits.iter() {
            let train_set = as_set(&split.train_indices);
            let test_set = as_set(&split.test_indices);

            // Train and test never overlap and together cover every row.
            assert!(train_set.is_disjoint(&test_set));
            let union: HashSet<usize> = train_set.union(&test_set).cloned().collect();
            assert_eq!(union, total_samples);
        }
    }

    #[test]
    fn test_label_statistics() {
        let labels = create_test_multilabel_data();
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default());
        validator
            .fit(&labels.view())
            .expect("operation should succeed");

        let stats = validator
            .get_label_statistics()
            .expect("operation should succeed");
        assert_eq!(stats.label_frequencies.len(), 4);
        assert_eq!(stats.label_proportions.len(), 4);
        assert!(stats.mean_labels_per_sample > 0.0);
        assert!(stats.label_cardinality > 0.0);
        assert!(stats.label_density > 0.0 && stats.label_density <= 1.0);
    }

    #[test]
    fn test_insufficient_samples() {
        // Two samples cannot fill five folds.
        let labels = arr2(&[[1, 0], [0, 1]]);
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig {
            n_folds: 5,
            ..Default::default()
        });

        let outcome = validator.split(&labels.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_labels() {
        let labels = Array2::<i32>::zeros((0, 0));
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default());

        let outcome = validator.fit(&labels.view());
        assert!(outcome.is_err());
    }
}