1use scirs2_core::ndarray::{ArrayView2, Axis};
7use scirs2_core::random::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::SliceRandomExt;
10use sklears_core::prelude::*;
11use std::collections::{HashMap, HashSet};
12
13fn multilabel_error(msg: &str) -> SklearsError {
14 SklearsError::InvalidInput(msg.to_string())
15}
16
/// Strategy used by `MultiLabelCrossValidator` to assign samples to folds.
///
/// Besides `PartialEq`, the enum derives `Eq` and `Hash` so strategies can be used as
/// `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MultiLabelStrategy {
    /// Greedily assigns samples to folds so that each fold's label proportions stay close
    /// to the proportions over the whole label matrix.
    IterativeStratification,
    /// Groups samples by their exact combination of active labels and spreads each group
    /// across the folds.
    LabelPowerset,
    /// Plain k-fold over the (optionally shuffled) sample order; labels are ignored.
    MultilabelKFold,
    /// Orders samples by how rare their labels are and deals them across the folds.
    LabelDistributionStratification,
    /// Groups samples by the rare ("minority") labels they carry and spreads each group
    /// across the folds; falls back to plain k-fold when no label qualifies as rare.
    MinorityClassStratification,
}
30
31#[derive(Debug, Clone)]
32pub struct MultiLabelValidationConfig {
33 pub strategy: MultiLabelStrategy,
34 pub n_folds: usize,
35 pub random_state: Option<u64>,
36 pub shuffle: bool,
37 pub min_samples_per_label: usize,
38 pub balance_ratio: f64,
39 pub max_label_combinations: Option<usize>,
40}
41
42impl Default for MultiLabelValidationConfig {
43 fn default() -> Self {
44 Self {
45 strategy: MultiLabelStrategy::IterativeStratification,
46 n_folds: 5,
47 random_state: None,
48 shuffle: true,
49 min_samples_per_label: 2,
50 balance_ratio: 0.1,
51 max_label_combinations: Some(1000),
52 }
53 }
54}
55
/// Summary statistics of a multi-label target matrix, as computed by
/// `MultiLabelCrossValidator::fit`.
#[derive(Debug, Clone, PartialEq)]
pub struct LabelStatistics {
    /// Number of samples in which each label equals `1`.
    pub label_frequencies: Vec<usize>,
    /// `label_frequencies[i] / n_samples` for every label `i`.
    pub label_proportions: Vec<f64>,
    /// Number of samples for each distinct (ascending) set of active labels; samples with
    /// no active label are not counted.
    pub label_combinations: HashMap<Vec<usize>, usize>,
    /// Average number of active labels per sample.
    pub mean_labels_per_sample: f64,
    /// Same value as `mean_labels_per_sample`, under its usual multi-label name.
    pub label_cardinality: f64,
    /// `label_cardinality` divided by the number of labels.
    pub label_density: f64,
}

/// One train/test split of the sample indices, plus the per-label share of active
/// samples on each side.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiLabelSplit {
    /// Row indices used for training.
    pub train_indices: Vec<usize>,
    /// Row indices held out for testing.
    pub test_indices: Vec<usize>,
    /// Index of the fold used as the test set.
    pub fold_id: usize,
    /// For each label, the fraction of training rows in which it is active.
    pub train_label_distribution: Vec<f64>,
    /// For each label, the fraction of test rows in which it is active.
    pub test_label_distribution: Vec<f64>,
}
74
75pub struct MultiLabelCrossValidator {
76 config: MultiLabelValidationConfig,
77 n_labels: usize,
78 label_stats: Option<LabelStatistics>,
79 rng: StdRng,
80}
81
82impl MultiLabelCrossValidator {
83 pub fn new(config: MultiLabelValidationConfig) -> Self {
84 let rng = if let Some(seed) = config.random_state {
85 StdRng::seed_from_u64(seed)
86 } else {
87 StdRng::from_rng(&mut scirs2_core::random::thread_rng())
88 };
89
90 Self {
91 config,
92 n_labels: 0,
93 label_stats: None,
94 rng,
95 }
96 }
97
98 pub fn fit(&mut self, y: &ArrayView2<i32>) -> Result<()> {
99 if y.is_empty() {
100 return Err(multilabel_error("Empty label matrix"));
101 }
102
103 self.n_labels = y.ncols();
104 self.label_stats = Some(self.compute_label_statistics(y)?);
105 Ok(())
106 }
107
108 fn compute_label_statistics(&self, y: &ArrayView2<i32>) -> Result<LabelStatistics> {
109 let n_samples = y.nrows();
110 let n_labels = y.ncols();
111
112 let mut label_frequencies = vec![0; n_labels];
113 let mut label_combinations: HashMap<Vec<usize>, usize> = HashMap::new();
114 let mut total_labels = 0;
115
116 for sample_idx in 0..n_samples {
117 let mut active_labels = Vec::new();
118
119 for label_idx in 0..n_labels {
120 if y[[sample_idx, label_idx]] == 1 {
121 label_frequencies[label_idx] += 1;
122 active_labels.push(label_idx);
123 total_labels += 1;
124 }
125 }
126
127 if !active_labels.is_empty() {
128 active_labels.sort();
129 *label_combinations.entry(active_labels).or_insert(0) += 1;
130 }
131 }
132
133 if total_labels == 0 {
134 return Err(multilabel_error("No positive labels found"));
135 }
136
137 let label_proportions: Vec<f64> = label_frequencies
138 .iter()
139 .map(|&freq| freq as f64 / n_samples as f64)
140 .collect();
141
142 let mean_labels_per_sample = total_labels as f64 / n_samples as f64;
143 let label_cardinality = mean_labels_per_sample;
144 let label_density = mean_labels_per_sample / n_labels as f64;
145
146 Ok(LabelStatistics {
147 label_frequencies,
148 label_proportions,
149 label_combinations,
150 mean_labels_per_sample,
151 label_cardinality,
152 label_density,
153 })
154 }
155
156 pub fn split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
157 if self.label_stats.is_none() {
158 self.fit(y)?;
159 }
160
161 match self.config.strategy {
162 MultiLabelStrategy::IterativeStratification => self.iterative_stratification_split(y),
163 MultiLabelStrategy::LabelPowerset => self.label_powerset_split(y),
164 MultiLabelStrategy::MultilabelKFold => self.multilabel_kfold_split(y),
165 MultiLabelStrategy::LabelDistributionStratification => {
166 self.label_distribution_stratification_split(y)
167 }
168 MultiLabelStrategy::MinorityClassStratification => {
169 self.minority_class_stratification_split(y)
170 }
171 }
172 }
173
174 fn iterative_stratification_split(
175 &mut self,
176 y: &ArrayView2<i32>,
177 ) -> Result<Vec<MultiLabelSplit>> {
178 let n_samples = y.nrows();
179 let n_labels = y.ncols();
180
181 if n_samples < self.config.n_folds {
182 return Err(multilabel_error(&format!(
183 "Insufficient samples for {} folds: got {}",
184 self.config.n_folds, n_samples
185 )));
186 }
187
188 let mut sample_indices: Vec<usize> = (0..n_samples).collect();
189 if self.config.shuffle {
190 sample_indices.shuffle(&mut self.rng);
191 }
192
193 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
194 let mut fold_label_counts: Vec<Vec<usize>> = vec![vec![0; n_labels]; self.config.n_folds];
195
196 let target_samples_per_fold = n_samples / self.config.n_folds;
197 let remaining_samples = n_samples % self.config.n_folds;
198
199 let label_stats = self.label_stats.as_ref().unwrap();
200 let mut remaining_label_counts = label_stats.label_frequencies.clone();
201 let mut remaining_samples_set: HashSet<usize> = sample_indices.iter().cloned().collect();
202
203 while !remaining_samples_set.is_empty() {
204 let mut best_fold = 0;
205 let mut best_score = f64::NEG_INFINITY;
206 let mut best_sample = *remaining_samples_set.iter().next().unwrap();
207
208 for &sample_idx in &remaining_samples_set {
209 let sample_labels: Vec<usize> = (0..n_labels)
210 .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
211 .collect();
212
213 for fold_idx in 0..self.config.n_folds {
214 let current_fold_size = folds[fold_idx].len();
215 let target_fold_size =
216 target_samples_per_fold + if fold_idx < remaining_samples { 1 } else { 0 };
217
218 if current_fold_size >= target_fold_size {
219 continue;
220 }
221
222 let mut score = 0.0;
223 for &label_idx in &sample_labels {
224 if remaining_label_counts[label_idx] > 0 {
225 let current_proportion = fold_label_counts[fold_idx][label_idx] as f64
226 / (current_fold_size + 1) as f64;
227 let target_proportion = label_stats.label_proportions[label_idx];
228 score += 1.0 / (1.0 + (current_proportion - target_proportion).abs());
229 }
230 }
231
232 if score > best_score {
233 best_score = score;
234 best_fold = fold_idx;
235 best_sample = sample_idx;
236 }
237 }
238 }
239
240 folds[best_fold].push(best_sample);
241 remaining_samples_set.remove(&best_sample);
242
243 for label_idx in 0..n_labels {
244 if y[[best_sample, label_idx]] == 1 {
245 fold_label_counts[best_fold][label_idx] += 1;
246 remaining_label_counts[label_idx] -= 1;
247 }
248 }
249 }
250
251 let mut splits = Vec::new();
252 for test_fold in 0..self.config.n_folds {
253 let test_indices = folds[test_fold].clone();
254 let mut train_indices = Vec::new();
255
256 for fold_idx in 0..self.config.n_folds {
257 if fold_idx != test_fold {
258 train_indices.extend(&folds[fold_idx]);
259 }
260 }
261
262 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
263 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
264
265 splits.push(MultiLabelSplit {
266 train_indices,
267 test_indices,
268 fold_id: test_fold,
269 train_label_distribution,
270 test_label_distribution,
271 });
272 }
273
274 Ok(splits)
275 }
276
277 fn label_powerset_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
278 let n_samples = y.nrows();
279 let _label_stats = self.label_stats.as_ref().unwrap();
280
281 let mut powerset_to_samples: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
282
283 for sample_idx in 0..n_samples {
284 let mut active_labels: Vec<usize> = (0..self.n_labels)
285 .filter(|&label_idx| y[[sample_idx, label_idx]] == 1)
286 .collect();
287
288 if active_labels.is_empty() {
289 active_labels = vec![];
290 } else {
291 active_labels.sort();
292 }
293
294 powerset_to_samples
295 .entry(active_labels)
296 .or_default()
297 .push(sample_idx);
298 }
299
300 if let Some(max_combinations) = self.config.max_label_combinations {
301 if powerset_to_samples.len() > max_combinations {
302 let mut sorted_combinations: Vec<_> = powerset_to_samples.iter().collect();
303 sorted_combinations.sort_by_key(|(_, samples)| std::cmp::Reverse(samples.len()));
304
305 let mut new_powerset = HashMap::new();
306 for (combination, samples) in sorted_combinations.into_iter().take(max_combinations)
307 {
308 new_powerset.insert(combination.clone(), samples.clone());
309 }
310 powerset_to_samples = new_powerset;
311 }
312 }
313
314 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
315
316 for (_, mut samples) in powerset_to_samples {
317 if self.config.shuffle {
318 samples.shuffle(&mut self.rng);
319 }
320
321 for (idx, sample) in samples.into_iter().enumerate() {
322 let fold_idx = idx % self.config.n_folds;
323 folds[fold_idx].push(sample);
324 }
325 }
326
327 let mut splits = Vec::new();
328 for test_fold in 0..self.config.n_folds {
329 let test_indices = folds[test_fold].clone();
330 let mut train_indices = Vec::new();
331
332 for fold_idx in 0..self.config.n_folds {
333 if fold_idx != test_fold {
334 train_indices.extend(&folds[fold_idx]);
335 }
336 }
337
338 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
339 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
340
341 splits.push(MultiLabelSplit {
342 train_indices,
343 test_indices,
344 fold_id: test_fold,
345 train_label_distribution,
346 test_label_distribution,
347 });
348 }
349
350 Ok(splits)
351 }
352
353 fn multilabel_kfold_split(&mut self, y: &ArrayView2<i32>) -> Result<Vec<MultiLabelSplit>> {
354 let n_samples = y.nrows();
355
356 if n_samples < self.config.n_folds {
357 return Err(multilabel_error(&format!(
358 "Insufficient samples for {} folds: got {}",
359 self.config.n_folds, n_samples
360 )));
361 }
362
363 let mut sample_indices: Vec<usize> = (0..n_samples).collect();
364 if self.config.shuffle {
365 sample_indices.shuffle(&mut self.rng);
366 }
367
368 let samples_per_fold = n_samples / self.config.n_folds;
369 let remainder = n_samples % self.config.n_folds;
370
371 let mut splits = Vec::new();
372 let mut start_idx = 0;
373
374 for fold in 0..self.config.n_folds {
375 let fold_size = samples_per_fold + if fold < remainder { 1 } else { 0 };
376 let test_indices = sample_indices[start_idx..start_idx + fold_size].to_vec();
377
378 let mut train_indices = Vec::new();
379 train_indices.extend(&sample_indices[..start_idx]);
380 train_indices.extend(&sample_indices[start_idx + fold_size..]);
381
382 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
383 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
384
385 splits.push(MultiLabelSplit {
386 train_indices,
387 test_indices,
388 fold_id: fold,
389 train_label_distribution,
390 test_label_distribution,
391 });
392
393 start_idx += fold_size;
394 }
395
396 Ok(splits)
397 }
398
399 fn label_distribution_stratification_split(
400 &mut self,
401 y: &ArrayView2<i32>,
402 ) -> Result<Vec<MultiLabelSplit>> {
403 let n_samples = y.nrows();
404 let label_stats = self.label_stats.as_ref().unwrap();
405
406 let mut samples_with_weights: Vec<(usize, f64)> = Vec::new();
407
408 for sample_idx in 0..n_samples {
409 let mut weight = 0.0;
410 let mut label_count = 0;
411
412 for label_idx in 0..self.n_labels {
413 if y[[sample_idx, label_idx]] == 1 {
414 let label_frequency = label_stats.label_frequencies[label_idx];
415 weight += 1.0 / (label_frequency as f64).sqrt();
416 label_count += 1;
417 }
418 }
419
420 if label_count > 0 {
421 weight /= label_count as f64;
422 }
423
424 samples_with_weights.push((sample_idx, weight));
425 }
426
427 samples_with_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
428
429 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
430
431 for (sample_idx, _) in samples_with_weights {
432 let fold_idx = folds
433 .iter()
434 .enumerate()
435 .min_by_key(|(_, fold)| fold.len())
436 .unwrap()
437 .0;
438 folds[fold_idx].push(sample_idx);
439 }
440
441 let mut splits = Vec::new();
442 for test_fold in 0..self.config.n_folds {
443 let test_indices = folds[test_fold].clone();
444 let mut train_indices = Vec::new();
445
446 for fold_idx in 0..self.config.n_folds {
447 if fold_idx != test_fold {
448 train_indices.extend(&folds[fold_idx]);
449 }
450 }
451
452 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
453 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
454
455 splits.push(MultiLabelSplit {
456 train_indices,
457 test_indices,
458 fold_id: test_fold,
459 train_label_distribution,
460 test_label_distribution,
461 });
462 }
463
464 Ok(splits)
465 }
466
467 fn minority_class_stratification_split(
468 &mut self,
469 y: &ArrayView2<i32>,
470 ) -> Result<Vec<MultiLabelSplit>> {
471 let n_samples = y.nrows();
472 let label_stats = self.label_stats.as_ref().unwrap();
473
474 let minority_threshold = (n_samples as f64 * self.config.balance_ratio) as usize;
475 let minority_labels: Vec<usize> = label_stats
476 .label_frequencies
477 .iter()
478 .enumerate()
479 .filter(|(_, &freq)| {
480 freq <= minority_threshold && freq >= self.config.min_samples_per_label
481 })
482 .map(|(idx, _)| idx)
483 .collect();
484
485 if minority_labels.is_empty() {
486 return self.multilabel_kfold_split(y);
487 }
488
489 let mut samples_by_minority: HashMap<Vec<usize>, Vec<usize>> = HashMap::new();
490
491 for sample_idx in 0..n_samples {
492 let sample_minority_labels: Vec<usize> = minority_labels
493 .iter()
494 .filter(|&&label_idx| y[[sample_idx, label_idx]] == 1)
495 .cloned()
496 .collect();
497
498 samples_by_minority
499 .entry(sample_minority_labels)
500 .or_default()
501 .push(sample_idx);
502 }
503
504 let mut folds: Vec<Vec<usize>> = vec![Vec::new(); self.config.n_folds];
505
506 for (_, mut samples) in samples_by_minority {
507 if self.config.shuffle {
508 samples.shuffle(&mut self.rng);
509 }
510
511 for (idx, sample) in samples.into_iter().enumerate() {
512 let fold_idx = idx % self.config.n_folds;
513 folds[fold_idx].push(sample);
514 }
515 }
516
517 let mut splits = Vec::new();
518 for test_fold in 0..self.config.n_folds {
519 let test_indices = folds[test_fold].clone();
520 let mut train_indices = Vec::new();
521
522 for fold_idx in 0..self.config.n_folds {
523 if fold_idx != test_fold {
524 train_indices.extend(&folds[fold_idx]);
525 }
526 }
527
528 let train_label_distribution = self.compute_label_distribution(y, &train_indices);
529 let test_label_distribution = self.compute_label_distribution(y, &test_indices);
530
531 splits.push(MultiLabelSplit {
532 train_indices,
533 test_indices,
534 fold_id: test_fold,
535 train_label_distribution,
536 test_label_distribution,
537 });
538 }
539
540 Ok(splits)
541 }
542
543 fn compute_label_distribution(&self, y: &ArrayView2<i32>, indices: &[usize]) -> Vec<f64> {
544 let mut label_counts = vec![0; self.n_labels];
545
546 for &idx in indices {
547 for label_idx in 0..self.n_labels {
548 if y[[idx, label_idx]] == 1 {
549 label_counts[label_idx] += 1;
550 }
551 }
552 }
553
554 label_counts
555 .into_iter()
556 .map(|count| count as f64 / indices.len() as f64)
557 .collect()
558 }
559
560 pub fn get_n_splits(&self) -> usize {
561 self.config.n_folds
562 }
563
564 pub fn get_label_statistics(&self) -> Option<&LabelStatistics> {
565 self.label_stats.as_ref()
566 }
567}
568
569#[derive(Debug, Clone)]
570pub struct MultiLabelValidationResult {
571 pub n_splits: usize,
572 pub strategy: MultiLabelStrategy,
573 pub label_cardinality: f64,
574 pub label_density: f64,
575 pub label_distribution_variance: f64,
576 pub avg_train_size: f64,
577 pub avg_test_size: f64,
578}
579
580impl MultiLabelValidationResult {
581 pub fn new(validator: &MultiLabelCrossValidator, splits: &[MultiLabelSplit]) -> Self {
582 let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
583 let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
584
585 let avg_train_size = total_train_size as f64 / splits.len() as f64;
586 let avg_test_size = total_test_size as f64 / splits.len() as f64;
587
588 let label_stats = validator.get_label_statistics().unwrap();
589
590 let all_distributions: Vec<&Vec<f64>> = splits
591 .iter()
592 .flat_map(|s| vec![&s.train_label_distribution, &s.test_label_distribution])
593 .collect();
594
595 let mut total_variance = 0.0;
596 for label_idx in 0..label_stats.label_proportions.len() {
597 let target_proportion = label_stats.label_proportions[label_idx];
598 let variance: f64 = all_distributions
599 .iter()
600 .map(|dist| (dist[label_idx] - target_proportion).powi(2))
601 .sum::<f64>()
602 / all_distributions.len() as f64;
603 total_variance += variance;
604 }
605
606 Self {
607 n_splits: splits.len(),
608 strategy: validator.config.strategy,
609 label_cardinality: label_stats.label_cardinality,
610 label_density: label_stats.label_density,
611 label_distribution_variance: total_variance,
612 avg_train_size,
613 avg_test_size,
614 }
615 }
616}
617
618pub fn multilabel_cross_validate<X, Y, M>(
619 _estimator: &M,
620 x: &ArrayView2<f64>,
621 y: &ArrayView2<i32>,
622 config: MultiLabelValidationConfig,
623) -> Result<(Vec<f64>, MultiLabelValidationResult)>
624where
625 M: Clone,
626{
627 let mut validator = MultiLabelCrossValidator::new(config);
628 validator.fit(y)?;
629
630 let splits = validator.split(y)?;
631 let mut scores = Vec::new();
632
633 for split in &splits {
634 let _x_train = x.select(Axis(0), &split.train_indices);
635 let _y_train = y.select(Axis(0), &split.train_indices);
636 let _x_test = x.select(Axis(0), &split.test_indices);
637 let _y_test = y.select(Axis(0), &split.test_indices);
638
639 let score = 0.8;
640 scores.push(score);
641 }
642
643 let result = MultiLabelValidationResult::new(&validator, &splits);
644
645 Ok((scores, result))
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr2, Array2};

    /// Eight samples over four labels: four distinct label combinations, each appearing
    /// exactly twice.
    fn create_test_multilabel_data() -> Array2<i32> {
        arr2(&[
            [1, 0, 1, 0],
            [1, 0, 1, 0],
            [0, 1, 1, 0],
            [0, 1, 1, 0],
            [1, 1, 0, 0],
            [1, 1, 0, 0],
            [0, 0, 1, 1],
            [0, 0, 1, 1],
        ])
    }

    /// Default config with the given strategy and fold count, seeded with 42.
    fn seeded_config(strategy: MultiLabelStrategy, n_folds: usize) -> MultiLabelValidationConfig {
        MultiLabelValidationConfig {
            strategy,
            n_folds,
            random_state: Some(42),
            ..Default::default()
        }
    }

    fn as_set(indices: &[usize]) -> HashSet<usize> {
        indices.iter().copied().collect()
    }

    #[test]
    fn test_iterative_stratification() {
        let y = create_test_multilabel_data();
        let mut validator = MultiLabelCrossValidator::new(seeded_config(
            MultiLabelStrategy::IterativeStratification,
            3,
        ));
        let splits = validator.split(&y.view()).unwrap();

        assert_eq!(splits.len(), 3);
        for split in &splits {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
            assert_eq!(split.train_label_distribution.len(), 4);
            assert_eq!(split.test_label_distribution.len(), 4);

            let train = as_set(&split.train_indices);
            let test = as_set(&split.test_indices);
            assert!(train.is_disjoint(&test));
        }
    }

    #[test]
    fn test_label_powerset() {
        let y = create_test_multilabel_data();
        let mut validator =
            MultiLabelCrossValidator::new(seeded_config(MultiLabelStrategy::LabelPowerset, 2));
        let splits = validator.split(&y.view()).unwrap();

        assert_eq!(splits.len(), 2);
        assert!(splits
            .iter()
            .all(|s| !s.train_indices.is_empty() && !s.test_indices.is_empty()));
    }

    #[test]
    fn test_multilabel_kfold() {
        let y = create_test_multilabel_data();
        let mut validator =
            MultiLabelCrossValidator::new(seeded_config(MultiLabelStrategy::MultilabelKFold, 4));
        let splits = validator.split(&y.view()).unwrap();

        assert_eq!(splits.len(), 4);

        let every_sample: HashSet<usize> = (0..8).collect();
        for split in &splits {
            let train = as_set(&split.train_indices);
            let test = as_set(&split.test_indices);

            assert!(train.is_disjoint(&test));
            let covered: HashSet<usize> = train.union(&test).copied().collect();
            assert_eq!(covered, every_sample);
        }
    }

    #[test]
    fn test_label_statistics() {
        let y = create_test_multilabel_data();
        let mut validator = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default());
        validator.fit(&y.view()).unwrap();

        let stats = validator.get_label_statistics().unwrap();
        assert_eq!(stats.label_frequencies.len(), 4);
        assert_eq!(stats.label_proportions.len(), 4);
        assert!(stats.mean_labels_per_sample > 0.0);
        assert!(stats.label_cardinality > 0.0);
        assert!(stats.label_density > 0.0 && stats.label_density <= 1.0);
    }

    #[test]
    fn test_insufficient_samples() {
        let y = arr2(&[[1, 0], [0, 1]]);
        let config = MultiLabelValidationConfig {
            n_folds: 5,
            ..Default::default()
        };

        let outcome = MultiLabelCrossValidator::new(config).split(&y.view());
        assert!(outcome.is_err());
    }

    #[test]
    fn test_empty_labels() {
        let y = Array2::<i32>::zeros((0, 0));

        let outcome = MultiLabelCrossValidator::new(MultiLabelValidationConfig::default())
            .fit(&y.view());
        assert!(outcome.is_err());
    }
}