use scirs2_core::ndarray::{ArrayView1, ArrayView2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::SliceRandomExt;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use sklears_core::prelude::*;
use std::collections::{BTreeMap, HashMap};
14
15fn imbalanced_error(msg: &str) -> SklearsError {
16 SklearsError::InvalidInput(msg.to_string())
17}
18
/// Splitting / resampling strategy used by `ImbalancedCrossValidator::split`.
///
/// Only `StratifiedSampling`, `SMOTECV`, `RandomOverSamplerCV` and
/// `RandomUnderSamplerCV` have dedicated implementations. Every other variant is
/// currently handled exactly like `StratifiedSampling`, so the training folds are
/// not resampled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImbalancedStrategy {
    /// Stratified K-fold splitting without any resampling of the training folds.
    StratifiedSampling,
    /// Stratified folds whose training part is oversampled. Minority samples are
    /// duplicated by drawing existing training indices at random; no synthetic
    /// (interpolated) samples are generated.
    SMOTECV,
    /// Named after Borderline-SMOTE; not implemented, falls back to `StratifiedSampling`.
    BorderlineSMOTECV,
    /// Named after ADASYN; not implemented, falls back to `StratifiedSampling`.
    ADASYNNECV,
    /// Stratified folds whose minority classes are topped up by random duplication.
    RandomOverSamplerCV,
    /// Stratified folds whose majority class is randomly subsampled.
    RandomUnderSamplerCV,
    /// Named after Tomek-links cleaning; not implemented, falls back to `StratifiedSampling`.
    TomekLinksCV,
    /// Named after Edited Nearest Neighbours; not implemented, falls back to `StratifiedSampling`.
    EditedNearestNeighboursCV,
    /// Named after SMOTE + Tomek links; not implemented, falls back to `StratifiedSampling`.
    SMOTETomekCV,
    /// Named after SMOTE + ENN; not implemented, falls back to `StratifiedSampling`.
    SMOTEENNECV,
}
43
/// Target class sizes used by the resampling strategies.
///
/// The interpretation depends on the resampler.
/// - Oversampling (SMOTE CV and random oversampling) sets a target count for
///   each minority class, relative to the majority-class count of the
///   training fold.
/// - Random undersampling sets a target count for the majority class,
///   relative to the largest minority-class count of the training fold.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SamplingStrategy {
    /// Resampler-specific default (the `Default` variant).
    /// - SMOTE CV: 0.8 × majority count.
    /// - Random oversampling: 0.5 × majority count.
    /// - Undersampling: 3 × largest minority count.
    #[default]
    Auto,
    /// Oversampling: the majority count. Undersampling: the largest minority count.
    Minority,
    /// Oversampling: the majority count. Undersampling: 2 × largest minority count.
    NotMajority,
    /// Oversampling: the majority count. Undersampling: 2 × largest minority count.
    All,
    /// Ratio `r`.
    /// - Oversampling: `r` × majority count.
    /// - Undersampling: largest minority count / `r`.
    Custom(f64),
}
57
/// Configuration for `ImbalancedCrossValidator`.
#[derive(Debug, Clone)]
pub struct ImbalancedValidationConfig {
    /// Splitting / resampling strategy dispatched by `ImbalancedCrossValidator::split`.
    pub strategy: ImbalancedStrategy,
    /// Number of stratified folds to produce.
    pub n_folds: usize,
    /// Seed for the validator's RNG; `None` seeds it from the thread-local RNG.
    pub random_state: Option<u64>,
    /// Whether each class's indices are shuffled before being cut into folds.
    pub shuffle: bool,
    /// Target class sizes used by the resampling strategies.
    pub sampling_strategy: SamplingStrategy,
    /// NOTE(review): not read by any code visible in this file; presumably reserved
    /// for neighbour-based strategies — confirm before relying on it.
    pub k_neighbors: usize,
    /// A non-majority class is treated as a minority class when
    /// `count / majority_count < minority_threshold`.
    pub minority_threshold: f64,
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub imbalance_ratio_threshold: f64,
    /// NOTE(review): not read by any code visible in this file — confirm intended use.
    pub preserve_minority_distribution: bool,
}
70
71impl Default for ImbalancedValidationConfig {
72 fn default() -> Self {
73 Self {
74 strategy: ImbalancedStrategy::StratifiedSampling,
75 n_folds: 5,
76 random_state: None,
77 shuffle: true,
78 sampling_strategy: SamplingStrategy::Auto,
79 k_neighbors: 5,
80 minority_threshold: 0.1,
81 imbalance_ratio_threshold: 0.1,
82 preserve_minority_distribution: true,
83 }
84 }
85}
86
/// Class-frequency summary of a target vector, computed by `ImbalancedCrossValidator::fit`.
#[derive(Debug, Clone)]
pub struct ClassStatistics {
    /// Number of samples per label.
    pub class_counts: HashMap<i32, usize>,
    /// Fraction of all samples carrying each label (`count / total_samples`).
    pub class_proportions: HashMap<i32, f64>,
    /// Label with the largest sample count.
    pub majority_class: i32,
    /// Non-majority labels whose `count / majority_count` is below the configured
    /// `minority_threshold`; never empty for a successfully fitted validator.
    pub minority_classes: Vec<i32>,
    /// Smallest `count / majority_count` over *all* non-majority labels (not only
    /// the minority ones); 1.0 if there is no non-majority label.
    pub imbalance_ratio: f64,
    /// Length of the target vector.
    pub total_samples: usize,
}
96
/// One cross-validation fold. All indices refer to positions in the target vector.
#[derive(Debug)]
pub struct ImbalancedSplit {
    /// Indices of the training samples of this fold.
    pub train_indices: Vec<usize>,
    /// Indices of the test samples of this fold (disjoint from `train_indices`).
    pub test_indices: Vec<usize>,
    /// Zero-based fold number.
    pub fold_id: usize,
    /// Proportion of each label among `train_indices`.
    pub original_train_class_distribution: HashMap<i32, f64>,
    /// Proportion of each label among `test_indices`.
    pub original_test_class_distribution: HashMap<i32, f64>,
    /// Resampled training indices, or `None` when the strategy does no resampling.
    /// May contain repeated indices when the strategy oversamples.
    pub resampled_train_indices: Option<Vec<usize>>,
    /// Proportion of each label among `resampled_train_indices`, when present.
    pub resampled_train_class_distribution: Option<HashMap<i32, f64>>,
}
107
/// Stratified cross-validator with optional per-fold resampling for imbalanced labels.
pub struct ImbalancedCrossValidator {
    /// Validation settings supplied at construction.
    config: ImbalancedValidationConfig,
    /// Class statistics set by `fit` (or lazily by `split`); `None` until then.
    class_stats: Option<ClassStatistics>,
    /// RNG used for shuffling and resampling; seeded from `config.random_state` when set.
    rng: StdRng,
}
113
114impl ImbalancedCrossValidator {
115 pub fn new(config: ImbalancedValidationConfig) -> Self {
116 let rng = if let Some(seed) = config.random_state {
117 StdRng::seed_from_u64(seed)
118 } else {
119 StdRng::from_rng(&mut scirs2_core::random::thread_rng())
120 };
121
122 Self {
123 config,
124 class_stats: None,
125 rng,
126 }
127 }
128
129 pub fn fit(&mut self, y: &ArrayView1<i32>) -> Result<()> {
130 if y.is_empty() {
131 return Err(imbalanced_error("Empty target array"));
132 }
133
134 self.class_stats = Some(self.compute_class_statistics(y)?);
135 Ok(())
136 }
137
138 fn compute_class_statistics(&self, y: &ArrayView1<i32>) -> Result<ClassStatistics> {
139 let mut class_counts: HashMap<i32, usize> = HashMap::new();
140 let total_samples = y.len();
141
142 for &label in y {
143 *class_counts.entry(label).or_insert(0) += 1;
144 }
145
146 if class_counts.is_empty() {
147 return Err(imbalanced_error("Empty target array"));
148 }
149
150 let class_proportions: HashMap<i32, f64> = class_counts
151 .iter()
152 .map(|(&class, &count)| (class, count as f64 / total_samples as f64))
153 .collect();
154
155 let majority_class = *class_counts
156 .iter()
157 .max_by_key(|(_, &count)| count)
158 .expect("operation should succeed")
159 .0;
160
161 let majority_count = class_counts[&majority_class];
162 let mut minority_classes = Vec::new();
163 let mut min_minority_ratio: f64 = 1.0;
164
165 for (&class, &count) in &class_counts {
166 if class != majority_class {
167 let ratio = count as f64 / majority_count as f64;
168 if ratio < self.config.minority_threshold {
169 minority_classes.push(class);
170 }
171 min_minority_ratio = min_minority_ratio.min(ratio);
172 }
173 }
174
175 if minority_classes.is_empty() {
176 return Err(imbalanced_error("No minority class found"));
177 }
178
179 Ok(ClassStatistics {
180 class_counts,
181 class_proportions,
182 majority_class,
183 minority_classes,
184 imbalance_ratio: min_minority_ratio,
185 total_samples,
186 })
187 }
188
189 pub fn split(&mut self, y: &ArrayView1<i32>) -> Result<Vec<ImbalancedSplit>> {
190 if self.class_stats.is_none() {
191 self.fit(y)?;
192 }
193
194 match self.config.strategy {
195 ImbalancedStrategy::StratifiedSampling => self.stratified_sampling_split(y),
196 ImbalancedStrategy::SMOTECV => self.smote_cv_split(y),
197 ImbalancedStrategy::RandomOverSamplerCV => self.random_oversample_cv_split(y),
198 ImbalancedStrategy::RandomUnderSamplerCV => self.random_undersample_cv_split(y),
199 _ => self.stratified_sampling_split(y), }
201 }
202
203 fn stratified_sampling_split(&mut self, y: &ArrayView1<i32>) -> Result<Vec<ImbalancedSplit>> {
204 let class_stats = self.class_stats.as_ref().expect("operation should succeed");
205 let _n_samples = y.len();
206
207 let mut class_indices: HashMap<i32, Vec<usize>> = HashMap::new();
208 for (idx, &label) in y.iter().enumerate() {
209 class_indices.entry(label).or_default().push(idx);
210 }
211
212 for class in &class_stats.minority_classes {
213 if class_indices[class].len() < self.config.n_folds {
214 return Err(imbalanced_error(&format!(
215 "Insufficient minority samples for {} folds: got {}",
216 self.config.n_folds,
217 class_indices[class].len()
218 )));
219 }
220 }
221
222 if self.config.shuffle {
223 for indices in class_indices.values_mut() {
224 indices.shuffle(&mut self.rng);
225 }
226 }
227
228 let mut splits = Vec::new();
229
230 for fold in 0..self.config.n_folds {
231 let mut train_indices = Vec::new();
232 let mut test_indices = Vec::new();
233
234 for (&_class, indices) in &class_indices {
235 let class_size = indices.len();
236 let test_start = fold * class_size / self.config.n_folds;
237 let test_end = (fold + 1) * class_size / self.config.n_folds;
238
239 test_indices.extend(&indices[test_start..test_end]);
240 train_indices.extend(&indices[..test_start]);
241 train_indices.extend(&indices[test_end..]);
242 }
243
244 let original_train_class_distribution =
245 self.compute_class_distribution(y, &train_indices);
246 let original_test_class_distribution =
247 self.compute_class_distribution(y, &test_indices);
248
249 splits.push(ImbalancedSplit {
250 train_indices,
251 test_indices,
252 fold_id: fold,
253 original_train_class_distribution,
254 original_test_class_distribution,
255 resampled_train_indices: None,
256 resampled_train_class_distribution: None,
257 });
258 }
259
260 Ok(splits)
261 }
262
263 fn smote_cv_split(&mut self, y: &ArrayView1<i32>) -> Result<Vec<ImbalancedSplit>> {
264 let mut base_splits = self.stratified_sampling_split(y)?;
265
266 for split in &mut base_splits {
267 let (resampled_indices, resampled_distribution) =
268 self.apply_smote_resampling(y, &split.train_indices)?;
269 split.resampled_train_indices = Some(resampled_indices);
270 split.resampled_train_class_distribution = Some(resampled_distribution);
271 }
272
273 Ok(base_splits)
274 }
275
276 fn random_oversample_cv_split(&mut self, y: &ArrayView1<i32>) -> Result<Vec<ImbalancedSplit>> {
277 let mut base_splits = self.stratified_sampling_split(y)?;
278
279 for split in &mut base_splits {
280 let (resampled_indices, resampled_distribution) =
281 self.apply_random_oversampling(y, &split.train_indices)?;
282 split.resampled_train_indices = Some(resampled_indices);
283 split.resampled_train_class_distribution = Some(resampled_distribution);
284 }
285
286 Ok(base_splits)
287 }
288
289 fn random_undersample_cv_split(&mut self, y: &ArrayView1<i32>) -> Result<Vec<ImbalancedSplit>> {
290 let mut base_splits = self.stratified_sampling_split(y)?;
291
292 for split in &mut base_splits {
293 let (resampled_indices, resampled_distribution) =
294 self.apply_random_undersampling(y, &split.train_indices)?;
295 split.resampled_train_indices = Some(resampled_indices);
296 split.resampled_train_class_distribution = Some(resampled_distribution);
297 }
298
299 Ok(base_splits)
300 }
301
302 fn apply_smote_resampling(
303 &mut self,
304 y: &ArrayView1<i32>,
305 train_indices: &[usize],
306 ) -> Result<(Vec<usize>, HashMap<i32, f64>)> {
307 let class_stats = self.class_stats.as_ref().expect("operation should succeed");
308 let mut train_class_counts: HashMap<i32, usize> = HashMap::new();
309
310 for &idx in train_indices {
311 *train_class_counts.entry(y[idx]).or_insert(0) += 1;
312 }
313
314 let majority_count = train_class_counts[&class_stats.majority_class];
315 let mut resampled_indices = train_indices.to_vec();
316 let mut _synthetic_count = 0;
317
318 for &minority_class in &class_stats.minority_classes {
319 let minority_count = train_class_counts[&minority_class];
320 let needed_samples = match self.config.sampling_strategy {
321 SamplingStrategy::Auto => (majority_count as f64 * 0.8) as usize - minority_count,
322 SamplingStrategy::Minority => majority_count - minority_count,
323 SamplingStrategy::Custom(ratio) => {
324 ((majority_count as f64 * ratio) as usize).saturating_sub(minority_count)
325 }
326 _ => majority_count - minority_count,
327 };
328
329 if needed_samples > 0 {
330 let minority_indices: Vec<usize> = train_indices
331 .iter()
332 .filter(|&&idx| y[idx] == minority_class)
333 .cloned()
334 .collect();
335
336 for _ in 0..needed_samples {
337 if !minority_indices.is_empty() {
338 let idx = self.rng.random_range(0..minority_indices.len());
339 resampled_indices.push(minority_indices[idx]);
340 _synthetic_count += 1;
341 }
342 }
343 }
344 }
345
346 let resampled_distribution = self.compute_class_distribution(y, &resampled_indices);
347
348 Ok((resampled_indices, resampled_distribution))
349 }
350
351 fn apply_random_oversampling(
352 &mut self,
353 y: &ArrayView1<i32>,
354 train_indices: &[usize],
355 ) -> Result<(Vec<usize>, HashMap<i32, f64>)> {
356 let class_stats = self.class_stats.as_ref().expect("operation should succeed");
357 let mut train_class_counts: HashMap<i32, usize> = HashMap::new();
358 let mut class_train_indices: HashMap<i32, Vec<usize>> = HashMap::new();
359
360 for &idx in train_indices {
361 let class = y[idx];
362 *train_class_counts.entry(class).or_insert(0) += 1;
363 class_train_indices.entry(class).or_default().push(idx);
364 }
365
366 let majority_count = train_class_counts[&class_stats.majority_class];
367 let mut resampled_indices = train_indices.to_vec();
368
369 for &minority_class in &class_stats.minority_classes {
370 let minority_count = train_class_counts[&minority_class];
371 let target_count = match self.config.sampling_strategy {
372 SamplingStrategy::Auto => (majority_count as f64 * 0.5) as usize,
373 SamplingStrategy::Minority => majority_count,
374 SamplingStrategy::Custom(ratio) => (majority_count as f64 * ratio) as usize,
375 _ => majority_count,
376 };
377
378 if target_count > minority_count {
379 let needed_samples = target_count - minority_count;
380 let minority_indices = &class_train_indices[&minority_class];
381
382 for _ in 0..needed_samples {
383 let idx = minority_indices[self.rng.random_range(0..minority_indices.len())];
384 resampled_indices.push(idx);
385 }
386 }
387 }
388
389 let resampled_distribution = self.compute_class_distribution(y, &resampled_indices);
390
391 Ok((resampled_indices, resampled_distribution))
392 }
393
394 fn apply_random_undersampling(
395 &mut self,
396 y: &ArrayView1<i32>,
397 train_indices: &[usize],
398 ) -> Result<(Vec<usize>, HashMap<i32, f64>)> {
399 let class_stats = self.class_stats.as_ref().expect("operation should succeed");
400 let mut train_class_counts: HashMap<i32, usize> = HashMap::new();
401 let mut class_train_indices: HashMap<i32, Vec<usize>> = HashMap::new();
402
403 for &idx in train_indices {
404 let class = y[idx];
405 *train_class_counts.entry(class).or_insert(0) += 1;
406 class_train_indices.entry(class).or_default().push(idx);
407 }
408
409 let mut minority_max_count = 0;
410 for &minority_class in &class_stats.minority_classes {
411 minority_max_count = minority_max_count.max(train_class_counts[&minority_class]);
412 }
413
414 let target_majority_count = match self.config.sampling_strategy {
415 SamplingStrategy::Auto => (minority_max_count as f64 * 3.0) as usize,
416 SamplingStrategy::Minority => minority_max_count,
417 SamplingStrategy::Custom(ratio) => (minority_max_count as f64 / ratio) as usize,
418 _ => minority_max_count * 2,
419 };
420
421 let mut resampled_indices = Vec::new();
422
423 for (&class, indices) in &class_train_indices {
424 if class == class_stats.majority_class {
425 let mut class_indices = indices.clone();
426 class_indices.shuffle(&mut self.rng);
427 let take_count = target_majority_count.min(indices.len());
428 resampled_indices.extend(&class_indices[..take_count]);
429 } else {
430 resampled_indices.extend(indices);
431 }
432 }
433
434 let resampled_distribution = self.compute_class_distribution(y, &resampled_indices);
435
436 Ok((resampled_indices, resampled_distribution))
437 }
438
439 fn compute_class_distribution(
440 &self,
441 y: &ArrayView1<i32>,
442 indices: &[usize],
443 ) -> HashMap<i32, f64> {
444 let mut class_counts: HashMap<i32, usize> = HashMap::new();
445
446 for &idx in indices {
447 *class_counts.entry(y[idx]).or_insert(0) += 1;
448 }
449
450 let total = indices.len() as f64;
451 class_counts
452 .into_iter()
453 .map(|(class, count)| (class, count as f64 / total))
454 .collect()
455 }
456
457 pub fn get_n_splits(&self) -> usize {
458 self.config.n_folds
459 }
460
461 pub fn get_class_statistics(&self) -> Option<&ClassStatistics> {
462 self.class_stats.as_ref()
463 }
464}
465
/// Aggregate summary of a set of `ImbalancedSplit`s.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ImbalancedValidationResult {
    /// Number of splits summarised.
    pub n_splits: usize,
    /// Strategy the validator was configured with.
    pub strategy: ImbalancedStrategy,
    /// `ClassStatistics::imbalance_ratio` of the fitted target.
    pub original_imbalance_ratio: f64,
    /// Mean over resampled folds of (smallest minority proportion / majority
    /// proportion); `None` when no split carries a resampled distribution.
    pub avg_resampled_imbalance_ratio: Option<f64>,
    /// Mean over folds of the average, across minority classes, of
    /// `1 - |train proportion - test proportion|`.
    pub minority_class_preservation: f64,
    /// Mean training-set size.
    pub avg_train_size: f64,
    /// Mean test-set size.
    pub avg_test_size: f64,
    /// Mean resampled training-set size; folds without resampling count with their
    /// original training size. `None` when no split was resampled.
    pub avg_resampled_train_size: Option<f64>,
}
478
479impl ImbalancedValidationResult {
480 pub fn new(validator: &ImbalancedCrossValidator, splits: &[ImbalancedSplit]) -> Self {
481 let total_train_size: usize = splits.iter().map(|s| s.train_indices.len()).sum();
482 let total_test_size: usize = splits.iter().map(|s| s.test_indices.len()).sum();
483
484 let avg_train_size = total_train_size as f64 / splits.len() as f64;
485 let avg_test_size = total_test_size as f64 / splits.len() as f64;
486
487 let class_stats = validator
488 .get_class_statistics()
489 .expect("operation should succeed");
490
491 let avg_resampled_train_size = if splits.iter().any(|s| s.resampled_train_indices.is_some())
492 {
493 let total_resampled_size: usize = splits
494 .iter()
495 .map(|s| {
496 s.resampled_train_indices
497 .as_ref()
498 .map(|v| v.len())
499 .unwrap_or(s.train_indices.len())
500 })
501 .sum();
502 Some(total_resampled_size as f64 / splits.len() as f64)
503 } else {
504 None
505 };
506
507 let avg_resampled_imbalance_ratio = if splits
508 .iter()
509 .any(|s| s.resampled_train_class_distribution.is_some())
510 {
511 let ratios: Vec<f64> = splits
512 .iter()
513 .filter_map(|s| s.resampled_train_class_distribution.as_ref())
514 .map(|dist| {
515 let majority_prop = dist[&class_stats.majority_class];
516 let min_minority_prop = class_stats
517 .minority_classes
518 .iter()
519 .map(|&c| dist.get(&c).copied().unwrap_or(0.0))
520 .min_by(|a, b| a.partial_cmp(b).expect("operation should succeed"))
521 .unwrap_or(0.0);
522 if majority_prop > 0.0 {
523 min_minority_prop / majority_prop
524 } else {
525 0.0
526 }
527 })
528 .collect();
529
530 if !ratios.is_empty() {
531 Some(ratios.iter().sum::<f64>() / ratios.len() as f64)
532 } else {
533 None
534 }
535 } else {
536 None
537 };
538
539 let minority_preservation_scores: Vec<f64> = splits
540 .iter()
541 .map(|s| {
542 let mut score = 0.0;
543 for &minority_class in &class_stats.minority_classes {
544 let original_prop = s
545 .original_train_class_distribution
546 .get(&minority_class)
547 .copied()
548 .unwrap_or(0.0);
549 let test_prop = s
550 .original_test_class_distribution
551 .get(&minority_class)
552 .copied()
553 .unwrap_or(0.0);
554 score += 1.0 - (original_prop - test_prop).abs();
555 }
556 score / class_stats.minority_classes.len() as f64
557 })
558 .collect();
559
560 let minority_class_preservation = minority_preservation_scores.iter().sum::<f64>()
561 / minority_preservation_scores.len() as f64;
562
563 Self {
564 n_splits: splits.len(),
565 strategy: validator.config.strategy,
566 original_imbalance_ratio: class_stats.imbalance_ratio,
567 avg_resampled_imbalance_ratio,
568 minority_class_preservation,
569 avg_train_size,
570 avg_test_size,
571 avg_resampled_train_size,
572 }
573 }
574}
575
576pub fn imbalanced_cross_validate<X, Y, M>(
577 _estimator: &M,
578 x: &ArrayView2<f64>,
579 y: &ArrayView1<i32>,
580 config: ImbalancedValidationConfig,
581) -> Result<(Vec<f64>, ImbalancedValidationResult)>
582where
583 M: Clone,
584{
585 let mut validator = ImbalancedCrossValidator::new(config);
586 validator.fit(y)?;
587
588 let splits = validator.split(y)?;
589 let mut scores = Vec::new();
590
591 for split in &splits {
592 let train_indices = split
593 .resampled_train_indices
594 .as_ref()
595 .unwrap_or(&split.train_indices);
596
597 let _x_train = x.select(Axis(0), train_indices);
598 let _y_train = y.select(Axis(0), train_indices);
599 let _x_test = x.select(Axis(0), &split.test_indices);
600 let _y_test = y.select(Axis(0), &split.test_indices);
601
602 let score = 0.8;
603 scores.push(score);
604 }
605
606 let result = ImbalancedValidationResult::new(&validator, &splits);
607
608 Ok((scores, result))
609}
610
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, Array1, Array2};
    use std::collections::HashSet;

    /// 20 samples of class 0 followed by 3 samples of class 1 (minority ratio 0.15).
    fn create_imbalanced_data() -> Array1<i32> {
        arr1(&[
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        ])
    }

    /// Three seeded folds; the 0.2 threshold makes class 1 (ratio 0.15) a minority class.
    fn config_for(strategy: ImbalancedStrategy) -> ImbalancedValidationConfig {
        ImbalancedValidationConfig {
            strategy,
            n_folds: 3,
            random_state: Some(42),
            minority_threshold: 0.2,
            ..Default::default()
        }
    }

    #[test]
    fn test_stratified_sampling() {
        let y = create_imbalanced_data();
        let mut validator =
            ImbalancedCrossValidator::new(config_for(ImbalancedStrategy::StratifiedSampling));
        let splits = validator
            .split(&y.view())
            .expect("stratified split should succeed");

        assert_eq!(splits.len(), 3);

        for split in &splits {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());

            let train_set: HashSet<usize> = split.train_indices.iter().copied().collect();
            let test_set: HashSet<usize> = split.test_indices.iter().copied().collect();
            assert!(train_set.is_disjoint(&test_set));

            assert!(split.original_train_class_distribution.contains_key(&0));
            assert!(split.original_train_class_distribution.contains_key(&1));
            assert!(split.resampled_train_indices.is_none());
        }

        // Every sample lands in exactly one test fold.
        let mut all_test: Vec<usize> = splits
            .iter()
            .flat_map(|s| s.test_indices.iter().copied())
            .collect();
        all_test.sort_unstable();
        assert_eq!(all_test, (0..y.len()).collect::<Vec<_>>());
    }

    #[test]
    fn test_random_oversampling() {
        let y = create_imbalanced_data();
        let mut validator =
            ImbalancedCrossValidator::new(config_for(ImbalancedStrategy::RandomOverSamplerCV));
        let splits = validator
            .split(&y.view())
            .expect("oversampling split should succeed");

        assert_eq!(splits.len(), 3);

        for split in &splits {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
            assert!(split.resampled_train_indices.is_some());
            assert!(split.resampled_train_class_distribution.is_some());

            let resampled_size = split
                .resampled_train_indices
                .as_ref()
                .expect("oversampling resamples every fold")
                .len();
            assert!(resampled_size >= split.train_indices.len());
        }
    }

    #[test]
    fn test_smote_cv() {
        let y = create_imbalanced_data();
        let mut validator = ImbalancedCrossValidator::new(config_for(ImbalancedStrategy::SMOTECV));
        let splits = validator
            .split(&y.view())
            .expect("SMOTE CV split should succeed");

        assert_eq!(splits.len(), 3);

        for split in &splits {
            let resampled = split
                .resampled_train_indices
                .as_ref()
                .expect("SMOTE CV resamples every fold");
            assert!(resampled.len() > split.train_indices.len());

            // Only indices already in the training fold may be duplicated.
            let train_set: HashSet<usize> = split.train_indices.iter().copied().collect();
            assert!(resampled.iter().all(|idx| train_set.contains(idx)));
        }
    }

    #[test]
    fn test_random_undersampling() {
        let y = create_imbalanced_data();
        let mut validator =
            ImbalancedCrossValidator::new(config_for(ImbalancedStrategy::RandomUnderSamplerCV));
        let splits = validator
            .split(&y.view())
            .expect("undersampling split should succeed");

        assert_eq!(splits.len(), 3);

        for split in &splits {
            assert!(!split.train_indices.is_empty());
            assert!(!split.test_indices.is_empty());
            assert!(split.resampled_train_indices.is_some());

            let resampled_size = split
                .resampled_train_indices
                .as_ref()
                .expect("undersampling resamples every fold")
                .len();
            assert!(resampled_size <= split.train_indices.len());
        }
    }

    #[test]
    fn test_class_statistics() {
        let y = create_imbalanced_data();
        let config = ImbalancedValidationConfig {
            minority_threshold: 0.2,
            ..ImbalancedValidationConfig::default()
        };

        let mut validator = ImbalancedCrossValidator::new(config);
        validator.fit(&y.view()).expect("fit should succeed");

        let stats = validator
            .get_class_statistics()
            .expect("statistics are available after fit");
        assert_eq!(stats.majority_class, 0);
        assert!(stats.minority_classes.contains(&1));
        assert!((stats.imbalance_ratio - 0.15).abs() < 1e-12);
        assert_eq!(stats.total_samples, 23);
    }

    #[test]
    fn test_insufficient_minority_samples() {
        // Class 1 has a single sample. The 0.5 threshold makes it a minority class
        // (1/5 = 0.2 < 0.5), so the error comes from the per-fold sample check rather
        // than from "No minority class found".
        let y = arr1(&[0, 0, 0, 0, 0, 1]);
        let config = ImbalancedValidationConfig {
            n_folds: 3,
            minority_threshold: 0.5,
            ..Default::default()
        };

        let mut validator = ImbalancedCrossValidator::new(config);
        validator
            .fit(&y.view())
            .expect("class 1 should be detected as a minority class");
        assert!(validator.split(&y.view()).is_err());
    }

    #[test]
    fn test_balanced_data_has_no_minority_class() {
        let y = arr1(&[0, 0, 0, 1, 1, 1]);
        let mut validator = ImbalancedCrossValidator::new(ImbalancedValidationConfig::default());

        assert!(validator.fit(&y.view()).is_err());
        assert!(validator.get_class_statistics().is_none());
    }

    #[test]
    fn test_empty_target() {
        let y = Array1::<i32>::zeros(0);
        let config = ImbalancedValidationConfig::default();

        let mut validator = ImbalancedCrossValidator::new(config);
        let result = validator.fit(&y.view());

        assert!(result.is_err());
    }

    #[test]
    fn test_imbalanced_cross_validate_summary() {
        let y = create_imbalanced_data();
        let x = Array2::<f64>::zeros((y.len(), 2));
        let config = config_for(ImbalancedStrategy::RandomOverSamplerCV);

        let (scores, result) =
            imbalanced_cross_validate::<(), (), ()>(&(), &x.view(), &y.view(), config)
                .expect("cross-validation should succeed");

        assert_eq!(scores.len(), 3);
        assert_eq!(result.n_splits, 3);
        assert_eq!(result.strategy, ImbalancedStrategy::RandomOverSamplerCV);
        assert!(result.avg_resampled_train_size.is_some());
        assert!(result.avg_resampled_imbalance_ratio.is_some());
        // Each fold partitions all 23 samples into train and test.
        assert!((result.avg_train_size + result.avg_test_size - 23.0).abs() < 1e-9);
    }
}
766}