1use scirs2_core::ndarray::{Array2, ArrayView2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::{Random, Rng};
9use sklears_core::{
10 error::Result as SklResult,
11 traits::{Fit, Transform},
12 types::Float,
13};
14use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
/// Result of running one imputation method on one dataset / missing-pattern pair.
#[derive(Clone, Debug)]
pub struct ImputationBenchmark {
    /// Name of the imputation method that produced this result.
    pub method_name: String,
    /// Name of the dataset the method was evaluated on.
    pub dataset_name: String,
    /// Fraction of matrix cells that were masked as missing for this run.
    pub missing_rate: f64,
    /// Name of the missing-data pattern that was applied.
    pub missing_pattern: String,
    /// Root mean squared error between true and imputed values on the masked cells.
    pub rmse: f64,
    /// Mean absolute error between true and imputed values on the masked cells.
    pub mae: f64,
    /// Wall-clock time measured for the run.
    pub execution_time: Duration,
    /// Optional memory usage figure; the unit is not established in this file
    /// (presumably bytes — confirm with whoever populates it).
    pub memory_usage: Option<usize>,
    /// Optional number of iterations an iterative method needed to converge.
    pub convergence_iterations: Option<usize>,
}
39
40#[derive(Debug, Clone)]
42pub struct ImputationComparison {
43 pub benchmarks: Vec<ImputationBenchmark>,
45 pub best_rmse_method: String,
47 pub best_mae_method: String,
49 pub fastest_method: String,
51 pub accuracy_rankings: HashMap<String, usize>,
53 pub speed_rankings: HashMap<String, usize>,
55}
56
57#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
59pub enum MissingPattern {
60 MCAR { missing_rate: f64 },
62 MAR {
64 missing_rate: f64,
65 dependency_strength: f64,
66 },
67 MNAR { missing_rate: f64, threshold: f64 },
69 Block {
71 block_size: usize,
72 missing_rate: f64,
73 },
74 Monotone { missing_rate: f64 },
76}
77
/// Builder-style generator of synthetic, fully observed data matrices for benchmarks.
pub struct BenchmarkDatasetGenerator {
    /// Number of rows in every generated matrix.
    n_samples: usize,
    /// Number of columns in every generated matrix.
    n_features: usize,
    /// Half-width of the uniform noise added to derived columns.
    noise_level: f64,
    /// Weight given to column 0 when generating correlated data.
    correlation_strength: f64,
    /// Optional seed requested by the caller.
    random_state: Option<u64>,
}
86
87impl BenchmarkDatasetGenerator {
88 pub fn new(n_samples: usize, n_features: usize) -> Self {
90 Self {
91 n_samples,
92 n_features,
93 noise_level: 0.1,
94 correlation_strength: 0.5,
95 random_state: None,
96 }
97 }
98
99 pub fn noise_level(mut self, noise_level: f64) -> Self {
101 self.noise_level = noise_level;
102 self
103 }
104
105 pub fn correlation_strength(mut self, correlation_strength: f64) -> Self {
107 self.correlation_strength = correlation_strength;
108 self
109 }
110
111 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
113 self.random_state = random_state;
114 self
115 }
116
117 pub fn generate_correlated_data(&self) -> SklResult<Array2<f64>> {
119 let mut rng = if let Some(_seed) = self.random_state {
120 Random::default()
121 } else {
122 Random::default()
123 };
124
125 let mut data = Array2::zeros((self.n_samples, self.n_features));
126
127 for i in 0..self.n_samples {
129 data[[i, 0]] = rng.gen_range(-3.0..3.0);
130 }
131
132 for j in 1..self.n_features {
134 let correlation = self.correlation_strength;
135 for i in 0..self.n_samples {
136 let base_value = data[[i, 0]];
137 let noise = rng.gen_range(-self.noise_level..self.noise_level);
138 data[[i, j]] = correlation * base_value
139 + (1.0 - correlation) * rng.gen_range(-2.0..2.0)
140 + noise;
141 }
142 }
143
144 Ok(data)
145 }
146
147 pub fn generate_linear_data(&self) -> SklResult<Array2<f64>> {
149 let mut rng = if let Some(_seed) = self.random_state {
150 Random::default()
151 } else {
152 Random::default()
153 };
154
155 let mut data = Array2::zeros((self.n_samples, self.n_features));
156
157 for i in 0..self.n_samples {
159 data[[i, 0]] = rng.gen_range(-5.0..5.0);
161
162 for j in 1..self.n_features {
164 let coef = (j as f64) * 0.5;
165 let noise = rng.gen_range(-self.noise_level..self.noise_level);
166 data[[i, j]] = coef * data[[i, 0]] + noise;
167 }
168 }
169
170 Ok(data)
171 }
172
173 pub fn generate_nonlinear_data(&self) -> SklResult<Array2<f64>> {
175 let mut rng = if let Some(_seed) = self.random_state {
176 Random::default()
177 } else {
178 Random::default()
179 };
180
181 let mut data = Array2::zeros((self.n_samples, self.n_features));
182
183 for i in 0..self.n_samples {
184 let x: f64 = rng.gen_range(-2.0..2.0);
185 data[[i, 0]] = x;
186
187 data[[i, 1]] = x.powi(2) + rng.gen_range(-self.noise_level..self.noise_level);
189
190 if self.n_features > 2 {
191 data[[i, 2]] = (x * 1.5).sin() + rng.gen_range(-self.noise_level..self.noise_level);
192 }
193
194 if self.n_features > 3 {
195 data[[i, 3]] = (x.powi(2) + x).exp() / 10.0
196 + rng.gen_range(-self.noise_level..self.noise_level);
197 }
198
199 for j in 4..self.n_features {
201 let noise = rng.gen_range(-self.noise_level..self.noise_level);
202 data[[i, j]] = (x + (j as f64) * 0.2).cos() + noise;
203 }
204 }
205
206 Ok(data)
207 }
208}
209
/// Masks cells of a complete data matrix according to a [`MissingPattern`].
pub struct MissingPatternGenerator {
    /// Optional seed requested by the caller.
    random_state: Option<u64>,
}
214
215impl MissingPatternGenerator {
216 pub fn new() -> Self {
218 Self { random_state: None }
219 }
220
221 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
223 self.random_state = random_state;
224 self
225 }
226
227 pub fn introduce_missing(
229 &self,
230 data: &Array2<f64>,
231 pattern: &MissingPattern,
232 ) -> SklResult<(Array2<f64>, Array2<bool>)> {
233 let mut rng = if let Some(_seed) = self.random_state {
234 Random::default()
235 } else {
236 Random::default()
237 };
238
239 let (n_samples, n_features) = data.dim();
240 let mut data_with_missing = data.clone();
241 let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
242
243 match pattern {
244 MissingPattern::MCAR { missing_rate } => {
245 self.introduce_mcar(
246 &mut data_with_missing,
247 &mut missing_mask,
248 *missing_rate,
249 &mut rng,
250 )?;
251 }
252 MissingPattern::MAR {
253 missing_rate,
254 dependency_strength,
255 } => {
256 self.introduce_mar(
257 data,
258 &mut data_with_missing,
259 &mut missing_mask,
260 *missing_rate,
261 *dependency_strength,
262 &mut rng,
263 )?;
264 }
265 MissingPattern::MNAR {
266 missing_rate,
267 threshold,
268 } => {
269 self.introduce_mnar(
270 data,
271 &mut data_with_missing,
272 &mut missing_mask,
273 *missing_rate,
274 *threshold,
275 &mut rng,
276 )?;
277 }
278 MissingPattern::Block {
279 block_size,
280 missing_rate,
281 } => {
282 self.introduce_block(
283 &mut data_with_missing,
284 &mut missing_mask,
285 *block_size,
286 *missing_rate,
287 &mut rng,
288 )?;
289 }
290 MissingPattern::Monotone { missing_rate } => {
291 self.introduce_monotone(
292 &mut data_with_missing,
293 &mut missing_mask,
294 *missing_rate,
295 &mut rng,
296 )?;
297 }
298 }
299
300 Ok((data_with_missing, missing_mask))
301 }
302
303 fn introduce_mcar(
304 &self,
305 data: &mut Array2<f64>,
306 missing_mask: &mut Array2<bool>,
307 missing_rate: f64,
308 rng: &mut Random,
309 ) -> SklResult<()> {
310 let total_elements = data.len();
311 let n_missing = (total_elements as f64 * missing_rate) as usize;
312
313 let mut positions: Vec<(usize, usize)> = Vec::new();
314 for i in 0..data.nrows() {
315 for j in 0..data.ncols() {
316 positions.push((i, j));
317 }
318 }
319
320 positions.shuffle(rng);
321
322 for &(i, j) in positions.iter().take(n_missing) {
323 data[[i, j]] = f64::NAN;
324 missing_mask[[i, j]] = true;
325 }
326
327 Ok(())
328 }
329
330 fn introduce_mar(
331 &self,
332 original_data: &Array2<f64>,
333 data: &mut Array2<f64>,
334 missing_mask: &mut Array2<bool>,
335 missing_rate: f64,
336 dependency_strength: f64,
337 rng: &mut Random,
338 ) -> SklResult<()> {
339 let (n_samples, n_features) = data.dim();
340
341 if n_features < 2 {
342 return self.introduce_mcar(data, missing_mask, missing_rate, rng);
343 }
344
345 let column_0_median = {
347 let mut sorted: Vec<f64> = original_data.column(0).iter().cloned().collect();
348 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
349 sorted[sorted.len() / 2]
350 };
351
352 for i in 0..n_samples {
353 for j in 1..n_features {
354 let base_prob = missing_rate;
355 let prob_adjustment = if original_data[[i, 0]] > column_0_median {
356 dependency_strength
357 } else {
358 -dependency_strength
359 };
360
361 let prob_missing = (base_prob + prob_adjustment).clamp(0.0, 1.0);
362
363 if rng.gen::<f64>() < prob_missing {
364 data[[i, j]] = f64::NAN;
365 missing_mask[[i, j]] = true;
366 }
367 }
368 }
369
370 Ok(())
371 }
372
373 fn introduce_mnar(
374 &self,
375 original_data: &Array2<f64>,
376 data: &mut Array2<f64>,
377 missing_mask: &mut Array2<bool>,
378 missing_rate: f64,
379 threshold: f64,
380 rng: &mut Random,
381 ) -> SklResult<()> {
382 let (n_samples, n_features) = data.dim();
383
384 for j in 0..n_features {
385 let column_values: Vec<f64> = original_data.column(j).iter().cloned().collect();
386 let column_threshold = {
387 let mut sorted = column_values.clone();
388 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
389 sorted[(sorted.len() as f64 * threshold) as usize]
390 };
391
392 for i in 0..n_samples {
393 let base_prob = missing_rate;
395 let prob_missing = if original_data[[i, j]] > column_threshold {
396 base_prob * 2.0
397 } else {
398 base_prob * 0.5
399 };
400
401 if rng.gen::<f64>() < prob_missing.min(1.0) {
402 data[[i, j]] = f64::NAN;
403 missing_mask[[i, j]] = true;
404 }
405 }
406 }
407
408 Ok(())
409 }
410
411 fn introduce_block(
412 &self,
413 data: &mut Array2<f64>,
414 missing_mask: &mut Array2<bool>,
415 block_size: usize,
416 missing_rate: f64,
417 rng: &mut Random,
418 ) -> SklResult<()> {
419 let (n_samples, n_features) = data.dim();
420 let n_blocks =
421 ((n_samples * n_features) as f64 * missing_rate / block_size as f64) as usize;
422
423 for _ in 0..n_blocks {
424 let start_i = rng.gen_range(0..n_samples);
425 let start_j = rng.gen_range(0..n_features);
426
427 let block_height = (block_size as f64).sqrt() as usize;
428 let block_width = block_size / block_height.max(1);
429
430 for di in 0..block_height {
431 for dj in 0..block_width {
432 let i = (start_i + di) % n_samples;
433 let j = (start_j + dj) % n_features;
434 data[[i, j]] = f64::NAN;
435 missing_mask[[i, j]] = true;
436 }
437 }
438 }
439
440 Ok(())
441 }
442
443 fn introduce_monotone(
444 &self,
445 data: &mut Array2<f64>,
446 missing_mask: &mut Array2<bool>,
447 missing_rate: f64,
448 rng: &mut Random,
449 ) -> SklResult<()> {
450 let (n_samples, n_features) = data.dim();
451
452 if n_features == 0 {
453 return Ok(());
454 }
455
456 let mut samples_to_affect: Vec<usize> = (0..n_samples).collect();
457 samples_to_affect.shuffle(rng);
458 let n_affected = (n_samples as f64 * missing_rate) as usize;
459
460 for &sample_idx in samples_to_affect.iter().take(n_affected) {
461 let start_feature = rng.gen_range(0..n_features);
463 for j in start_feature..n_features {
464 data[[sample_idx, j]] = f64::NAN;
465 missing_mask[[sample_idx, j]] = true;
466 }
467 }
468
469 Ok(())
470 }
471}
472
473impl Default for MissingPatternGenerator {
474 fn default() -> Self {
475 Self::new()
476 }
477}
478
/// Namespace for accuracy metrics evaluated only on cells that were masked as missing.
pub struct AccuracyMetrics;
481
482impl AccuracyMetrics {
483 pub fn rmse(
485 true_values: &Array2<f64>,
486 imputed_values: &Array2<f64>,
487 missing_mask: &Array2<bool>,
488 ) -> f64 {
489 let mut sum_squared_diff = 0.0;
490 let mut count = 0;
491
492 for ((i, j), &is_missing) in missing_mask.indexed_iter() {
493 if is_missing {
494 let diff = true_values[[i, j]] - imputed_values[[i, j]];
495 sum_squared_diff += diff * diff;
496 count += 1;
497 }
498 }
499
500 if count > 0 {
501 (sum_squared_diff / count as f64).sqrt()
502 } else {
503 0.0
504 }
505 }
506
507 pub fn mae(
509 true_values: &Array2<f64>,
510 imputed_values: &Array2<f64>,
511 missing_mask: &Array2<bool>,
512 ) -> f64 {
513 let mut sum_abs_diff = 0.0;
514 let mut count = 0;
515
516 for ((i, j), &is_missing) in missing_mask.indexed_iter() {
517 if is_missing {
518 let diff = (true_values[[i, j]] - imputed_values[[i, j]]).abs();
519 sum_abs_diff += diff;
520 count += 1;
521 }
522 }
523
524 if count > 0 {
525 sum_abs_diff / count as f64
526 } else {
527 0.0
528 }
529 }
530
531 pub fn bias(
533 true_values: &Array2<f64>,
534 imputed_values: &Array2<f64>,
535 missing_mask: &Array2<bool>,
536 ) -> f64 {
537 let mut sum_diff = 0.0;
538 let mut count = 0;
539
540 for ((i, j), &is_missing) in missing_mask.indexed_iter() {
541 if is_missing {
542 let diff = imputed_values[[i, j]] - true_values[[i, j]];
543 sum_diff += diff;
544 count += 1;
545 }
546 }
547
548 if count > 0 {
549 sum_diff / count as f64
550 } else {
551 0.0
552 }
553 }
554
555 pub fn r_squared(
557 true_values: &Array2<f64>,
558 imputed_values: &Array2<f64>,
559 missing_mask: &Array2<bool>,
560 ) -> f64 {
561 let mut missing_true_values = Vec::new();
562 let mut missing_imputed_values = Vec::new();
563
564 for ((i, j), &is_missing) in missing_mask.indexed_iter() {
565 if is_missing {
566 missing_true_values.push(true_values[[i, j]]);
567 missing_imputed_values.push(imputed_values[[i, j]]);
568 }
569 }
570
571 if missing_true_values.is_empty() {
572 return 1.0;
573 }
574
575 let true_mean = missing_true_values.iter().sum::<f64>() / missing_true_values.len() as f64;
576
577 let ss_tot: f64 = missing_true_values
578 .iter()
579 .map(|&x| (x - true_mean).powi(2))
580 .sum();
581
582 let ss_res: f64 = missing_true_values
583 .iter()
584 .zip(missing_imputed_values.iter())
585 .map(|(&true_val, &imputed_val)| (true_val - imputed_val).powi(2))
586 .sum();
587
588 if ss_tot == 0.0 {
589 1.0
590 } else {
591 1.0 - (ss_res / ss_tot)
592 }
593 }
594}
595
596pub struct BenchmarkSuite {
598 datasets: Vec<(String, Array2<f64>)>,
599 missing_patterns: Vec<(String, MissingPattern)>,
600 random_state: Option<u64>,
601}
602
603impl BenchmarkSuite {
604 pub fn new() -> Self {
606 Self {
607 datasets: Vec::new(),
608 missing_patterns: Vec::new(),
609 random_state: None,
610 }
611 }
612
613 pub fn add_dataset(mut self, name: String, data: Array2<f64>) -> Self {
615 self.datasets.push((name, data));
616 self
617 }
618
619 pub fn add_missing_pattern(mut self, name: String, pattern: MissingPattern) -> Self {
621 self.missing_patterns.push((name, pattern));
622 self
623 }
624
625 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
627 self.random_state = random_state;
628 self
629 }
630
631 pub fn add_standard_datasets(mut self) -> Self {
633 let linear_gen = BenchmarkDatasetGenerator::new(100, 4)
635 .correlation_strength(0.7)
636 .noise_level(0.1)
637 .random_state(Some(42));
638
639 if let Ok(linear_data) = linear_gen.generate_linear_data() {
640 self.datasets
641 .push(("linear_100x4".to_string(), linear_data));
642 }
643
644 let nonlinear_gen = BenchmarkDatasetGenerator::new(80, 3)
646 .noise_level(0.2)
647 .random_state(Some(123));
648
649 if let Ok(nonlinear_data) = nonlinear_gen.generate_nonlinear_data() {
650 self.datasets
651 .push(("nonlinear_80x3".to_string(), nonlinear_data));
652 }
653
654 let correlated_gen = BenchmarkDatasetGenerator::new(120, 5)
656 .correlation_strength(0.8)
657 .noise_level(0.05)
658 .random_state(Some(456));
659
660 if let Ok(correlated_data) = correlated_gen.generate_correlated_data() {
661 self.datasets
662 .push(("correlated_120x5".to_string(), correlated_data));
663 }
664
665 self
666 }
667
668 pub fn add_standard_patterns(mut self) -> Self {
670 self.missing_patterns.push((
671 "MCAR_15%".to_string(),
672 MissingPattern::MCAR { missing_rate: 0.15 },
673 ));
674
675 self.missing_patterns.push((
676 "MAR_20%".to_string(),
677 MissingPattern::MAR {
678 missing_rate: 0.20,
679 dependency_strength: 0.3,
680 },
681 ));
682
683 self.missing_patterns.push((
684 "MNAR_10%".to_string(),
685 MissingPattern::MNAR {
686 missing_rate: 0.10,
687 threshold: 0.7,
688 },
689 ));
690
691 self.missing_patterns.push((
692 "Block_12%".to_string(),
693 MissingPattern::Block {
694 block_size: 4,
695 missing_rate: 0.12,
696 },
697 ));
698
699 self
700 }
701
702 pub fn benchmark_imputer<I, T>(
704 &self,
705 imputer: I,
706 imputer_name: &str,
707 ) -> SklResult<Vec<ImputationBenchmark>>
708 where
709 I: Clone,
710 for<'a> I: Fit<ArrayView2<'a, Float>, (), Fitted = T>,
711 for<'a> T: Transform<ArrayView2<'a, Float>, Array2<Float>>,
712 {
713 let mut results = Vec::new();
714 let pattern_generator = MissingPatternGenerator::new().random_state(self.random_state);
715
716 for (dataset_name, true_data) in &self.datasets {
717 for (pattern_name, pattern) in &self.missing_patterns {
718 let (data_with_missing, missing_mask) =
719 pattern_generator.introduce_missing(true_data, pattern)?;
720
721 let data_float = data_with_missing.mapv(|x| x as Float);
722
723 let start_time = Instant::now();
725
726 let fitted = imputer.clone().fit(&data_float.view(), &())?;
727 let imputed_data = fitted.transform(&data_float.view())?;
728
729 let execution_time = start_time.elapsed();
730
731 let imputed_f64 = imputed_data.mapv(|x| x);
733 let rmse = AccuracyMetrics::rmse(true_data, &imputed_f64, &missing_mask);
734 let mae = AccuracyMetrics::mae(true_data, &imputed_f64, &missing_mask);
735
736 let missing_rate =
737 missing_mask.iter().filter(|&&x| x).count() as f64 / missing_mask.len() as f64;
738
739 results.push(ImputationBenchmark {
740 method_name: imputer_name.to_string(),
741 dataset_name: dataset_name.clone(),
742 missing_rate,
743 missing_pattern: pattern_name.clone(),
744 rmse,
745 mae,
746 execution_time,
747 memory_usage: None,
748 convergence_iterations: None,
749 });
750 }
751 }
752
753 Ok(results)
754 }
755
756 pub fn compare_imputers(&self, benchmarks: Vec<ImputationBenchmark>) -> ImputationComparison {
758 if benchmarks.is_empty() {
759 return ImputationComparison {
760 benchmarks: Vec::new(),
761 best_rmse_method: String::new(),
762 best_mae_method: String::new(),
763 fastest_method: String::new(),
764 accuracy_rankings: HashMap::new(),
765 speed_rankings: HashMap::new(),
766 };
767 }
768
769 let best_rmse = benchmarks
771 .iter()
772 .min_by(|a, b| a.rmse.partial_cmp(&b.rmse).unwrap());
773 let best_mae = benchmarks
774 .iter()
775 .min_by(|a, b| a.mae.partial_cmp(&b.mae).unwrap());
776 let fastest = benchmarks.iter().min_by_key(|b| b.execution_time);
777
778 let best_rmse_method = best_rmse.map(|b| b.method_name.clone()).unwrap_or_default();
779 let best_mae_method = best_mae.map(|b| b.method_name.clone()).unwrap_or_default();
780 let fastest_method = fastest.map(|b| b.method_name.clone()).unwrap_or_default();
781
782 let mut accuracy_rankings = HashMap::new();
784 let mut speed_rankings = HashMap::new();
785
786 let mut method_avg_rmse: HashMap<String, f64> = HashMap::new();
788 let mut method_avg_time: HashMap<String, Duration> = HashMap::new();
789 let mut method_counts: HashMap<String, usize> = HashMap::new();
790
791 for benchmark in &benchmarks {
792 let count = method_counts
793 .entry(benchmark.method_name.clone())
794 .or_insert(0);
795 *count += 1;
796
797 let avg_rmse = method_avg_rmse
798 .entry(benchmark.method_name.clone())
799 .or_insert(0.0);
800 *avg_rmse += benchmark.rmse;
801
802 let avg_time = method_avg_time
803 .entry(benchmark.method_name.clone())
804 .or_insert(Duration::ZERO);
805 *avg_time += benchmark.execution_time;
806 }
807
808 for (method, count) in &method_counts {
810 if let Some(total_rmse) = method_avg_rmse.get_mut(method) {
811 *total_rmse /= *count as f64;
812 }
813 if let Some(total_time) = method_avg_time.get_mut(method) {
814 *total_time /= *count as u32;
815 }
816 }
817
818 let mut rmse_pairs: Vec<_> = method_avg_rmse.into_iter().collect();
820 rmse_pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
821 for (rank, (method, _)) in rmse_pairs.into_iter().enumerate() {
822 accuracy_rankings.insert(method, rank + 1);
823 }
824
825 let mut time_pairs: Vec<_> = method_avg_time.into_iter().collect();
826 time_pairs.sort_by_key(|a| a.1);
827 for (rank, (method, _)) in time_pairs.into_iter().enumerate() {
828 speed_rankings.insert(method, rank + 1);
829 }
830
831 ImputationComparison {
832 benchmarks,
833 best_rmse_method,
834 best_mae_method,
835 fastest_method,
836 accuracy_rankings,
837 speed_rankings,
838 }
839 }
840
841 pub fn generate_report(&self, comparison: &ImputationComparison) -> String {
843 let mut report = String::new();
844
845 report.push_str("# Imputation Methods Benchmark Report\n\n");
846
847 report.push_str("## Summary\n");
848 report.push_str(&format!("- Best RMSE: {}\n", comparison.best_rmse_method));
849 report.push_str(&format!("- Best MAE: {}\n", comparison.best_mae_method));
850 report.push_str(&format!("- Fastest: {}\n\n", comparison.fastest_method));
851
852 report.push_str("## Accuracy Rankings\n");
853 let mut accuracy_pairs: Vec<_> = comparison.accuracy_rankings.iter().collect();
854 accuracy_pairs.sort_by_key(|&(_, rank)| rank);
855 for (method, rank) in accuracy_pairs {
856 report.push_str(&format!("{}. {}\n", rank, method));
857 }
858
859 report.push_str("\n## Speed Rankings\n");
860 let mut speed_pairs: Vec<_> = comparison.speed_rankings.iter().collect();
861 speed_pairs.sort_by_key(|&(_, rank)| rank);
862 for (method, rank) in speed_pairs {
863 report.push_str(&format!("{}. {}\n", rank, method));
864 }
865
866 report.push_str("\n## Detailed Results\n");
867 for benchmark in &comparison.benchmarks {
868 report.push_str(&format!(
869 "- {}: {} on {} ({}): RMSE={:.4}, MAE={:.4}, Time={:.2}ms\n",
870 benchmark.method_name,
871 benchmark.missing_pattern,
872 benchmark.dataset_name,
873 (benchmark.missing_rate * 100.0).round(),
874 benchmark.rmse,
875 benchmark.mae,
876 benchmark.execution_time.as_secs_f64() * 1000.0
877 ));
878 }
879
880 report
881 }
882}
883
884impl Default for BenchmarkSuite {
885 fn default() -> Self {
886 Self::new()
887 }
888}
889
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{KNNImputer, SimpleImputer};

    /// Deterministic `rows` x `cols` matrix whose entry at `(i, j)` is `i + j`.
    fn ramp(rows: usize, cols: usize) -> Array2<f64> {
        Array2::from_shape_fn((rows, cols), |(i, j)| (i + j) as f64)
    }

    /// Number of cells flagged as missing in `mask`.
    fn count_missing(mask: &Array2<bool>) -> usize {
        mask.iter().filter(|&&flag| flag).count()
    }

    /// Every generator honours the configured matrix shape.
    #[test]
    fn test_dataset_generation() {
        let generator = BenchmarkDatasetGenerator::new(50, 3).random_state(Some(42));

        let generated = [
            generator.generate_linear_data().unwrap(),
            generator.generate_nonlinear_data().unwrap(),
            generator.generate_correlated_data().unwrap(),
        ];

        for matrix in &generated {
            assert_eq!(matrix.shape(), &[50, 3]);
        }
    }

    /// MCAR, MAR and block patterns each mask at least one cell.
    #[test]
    fn test_missing_pattern_generation() {
        let data = ramp(20, 3);
        let generator = MissingPatternGenerator::new().random_state(Some(123));

        let (_data_mcar, mask_mcar) = generator
            .introduce_missing(&data, &MissingPattern::MCAR { missing_rate: 0.2 })
            .unwrap();
        let mcar_missing = count_missing(&mask_mcar);
        assert!(mcar_missing > 0);
        assert!(mcar_missing < data.len());

        let mar_pattern = MissingPattern::MAR {
            missing_rate: 0.15,
            dependency_strength: 0.3,
        };
        let (_data_mar, mask_mar) = generator.introduce_missing(&data, &mar_pattern).unwrap();
        assert!(count_missing(&mask_mar) > 0);

        let block_pattern = MissingPattern::Block {
            block_size: 4,
            missing_rate: 0.1,
        };
        let (_data_block, mask_block) =
            generator.introduce_missing(&data, &block_pattern).unwrap();
        assert!(count_missing(&mask_block) > 0);
    }

    /// Overshooting imputations give positive RMSE, MAE and bias.
    #[test]
    fn test_accuracy_metrics() {
        let true_data = ramp(10, 2);
        let mut imputed_data = true_data.clone();
        let mut missing_mask = Array2::from_elem((10, 2), false);

        // Both imputed values are larger than the true values they replace.
        for &(cell, imputed) in &[([0, 0], 10.0), ([1, 1], 20.0)] {
            imputed_data[cell] = imputed;
            missing_mask[cell] = true;
        }

        let rmse = AccuracyMetrics::rmse(&true_data, &imputed_data, &missing_mask);
        let mae = AccuracyMetrics::mae(&true_data, &imputed_data, &missing_mask);
        let bias = AccuracyMetrics::bias(&true_data, &imputed_data, &missing_mask);

        assert!(rmse > 0.0);
        assert!(mae > 0.0);
        assert!(bias > 0.0);
    }

    /// Two imputers can be benchmarked and then compared in one report.
    #[test]
    fn test_benchmark_suite() {
        let suite = BenchmarkSuite::new()
            .add_dataset("test_data".to_string(), ramp(30, 3))
            .add_missing_pattern(
                "test_mcar".to_string(),
                MissingPattern::MCAR { missing_rate: 0.1 },
            )
            .random_state(Some(42));

        let simple_imputer = SimpleImputer::new().strategy("mean".to_string());
        let simple_results = suite
            .benchmark_imputer(simple_imputer, "SimpleImputer")
            .unwrap();

        assert_eq!(simple_results.len(), 1);
        assert_eq!(simple_results[0].method_name, "SimpleImputer");
        assert!(simple_results[0].rmse >= 0.0);
        assert!(simple_results[0].mae >= 0.0);

        let knn_imputer = KNNImputer::new().n_neighbors(3);
        let knn_results = suite.benchmark_imputer(knn_imputer, "KNNImputer").unwrap();

        assert_eq!(knn_results.len(), 1);
        assert_eq!(knn_results[0].method_name, "KNNImputer");

        let all_results: Vec<ImputationBenchmark> =
            simple_results.into_iter().chain(knn_results).collect();
        let comparison = suite.compare_imputers(all_results);

        for method in ["SimpleImputer", "KNNImputer"] {
            assert!(comparison.accuracy_rankings.contains_key(method));
            assert!(comparison.speed_rankings.contains_key(method));
        }
    }

    /// The standard suite yields one result per dataset / pattern combination.
    #[test]
    fn test_standard_benchmarks() {
        let suite = BenchmarkSuite::new()
            .add_standard_datasets()
            .add_standard_patterns()
            .random_state(Some(42));

        assert!(!suite.datasets.is_empty());
        assert!(!suite.missing_patterns.is_empty());

        let simple_imputer = SimpleImputer::new().strategy("mean".to_string());
        let results = suite
            .benchmark_imputer(simple_imputer, "SimpleImputer")
            .unwrap();

        let expected_results = suite.datasets.len() * suite.missing_patterns.len();
        assert_eq!(results.len(), expected_results);

        for result in &results {
            assert!(result.rmse >= 0.0);
            assert!(result.mae >= 0.0);
            assert!(result.execution_time > Duration::ZERO);
        }
    }
}