sklears_impute/
benchmarks.rs

1//! Benchmarking and comparison utilities for imputation methods
2//!
3//! This module provides tools for comparing imputation methods against reference implementations
4//! and measuring performance across different scenarios and datasets.
5
6use scirs2_core::ndarray::{Array2, ArrayView2};
7use scirs2_core::rand_prelude::SliceRandom;
8use scirs2_core::random::{Random, Rng};
9use sklears_core::{
10    error::Result as SklResult,
11    traits::{Fit, Transform},
12    types::Float,
13};
14use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
/// Benchmark results for imputation methods
///
/// One record describes a single (method, dataset, missing-pattern) run, as produced by
/// [`BenchmarkSuite::benchmark_imputer`]. Accuracy metrics are computed only over the
/// cells that were artificially masked, not over the whole matrix.
#[derive(Debug, Clone)]
pub struct ImputationBenchmark {
    /// Name of the imputation method, as supplied by the caller.
    pub method_name: String,
    /// Name of the dataset the method was run on.
    pub dataset_name: String,
    /// Fraction of cells that were actually masked in this run (masked cells / all cells).
    /// Measured from the mask, so it can differ from the pattern's requested rate.
    pub missing_rate: f64,
    /// Name of the missing-data pattern used to mask the dataset.
    pub missing_pattern: String,
    /// Root mean squared error between imputed and true values on the masked cells.
    pub rmse: f64,
    /// Mean absolute error between imputed and true values on the masked cells.
    pub mae: f64,
    /// Wall-clock time measured around the imputer's `fit` and `transform` calls.
    pub execution_time: Duration,
    /// Memory usage, if measured. NOTE(review): unit not established here (presumably
    /// bytes — TODO confirm); `benchmark_imputer` currently always sets `None`.
    pub memory_usage: Option<usize>,
    /// Iterations to convergence for iterative methods, if reported;
    /// `benchmark_imputer` currently always sets `None`.
    pub convergence_iterations: Option<usize>,
}
39
/// Comparison results between methods
///
/// Built by [`BenchmarkSuite::compare_imputers`] from a flat list of benchmark records.
#[derive(Debug, Clone)]
pub struct ImputationComparison {
    /// All benchmark records the comparison was computed from.
    pub benchmarks: Vec<ImputationBenchmark>,
    /// Method of the single record with the lowest RMSE (empty when there were no records).
    pub best_rmse_method: String,
    /// Method of the single record with the lowest MAE (empty when there were no records).
    pub best_mae_method: String,
    /// Method of the single record with the shortest execution time (empty when no records).
    pub fastest_method: String,
    /// 1-based rank of each method by its average RMSE over its records (1 = most accurate).
    pub accuracy_rankings: HashMap<String, usize>,
    /// 1-based rank of each method by its average execution time (1 = fastest).
    pub speed_rankings: HashMap<String, usize>,
}
56
/// Missing data pattern types for benchmarking
///
/// Each variant is interpreted by [`MissingPatternGenerator::introduce_missing`].
/// Rates are meant to be fractions in `[0, 1]`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum MissingPattern {
    /// Missing Completely At Random: `missing_rate` of all cells, chosen uniformly
    /// without replacement.
    MCAR { missing_rate: f64 },
    /// Missing At Random - depends on observed values
    ///
    /// Missingness in columns `1..` depends on whether column 0 is above its median:
    /// each cell is masked with probability `missing_rate + dependency_strength` (above)
    /// or `missing_rate - dependency_strength` (at or below), clamped to `[0, 1]`.
    /// With fewer than two columns this falls back to MCAR.
    MAR {
        missing_rate: f64,
        dependency_strength: f64,
    },
    /// Missing Not At Random - depends on unobserved values
    ///
    /// `threshold` selects a quantile (fraction of the sorted column) per column; cells
    /// above it are masked with probability `2 * missing_rate`, the others with
    /// `0.5 * missing_rate` (both capped at 1).
    MNAR { missing_rate: f64, threshold: f64 },
    /// Block missing pattern: near-square blocks of roughly `block_size` cells placed at
    /// random (wrapping around the matrix edges), sized so that about `missing_rate` of
    /// the cells are targeted. Blocks may overlap.
    Block {
        block_size: usize,
        missing_rate: f64,
    },
    /// Monotone missing pattern: in a `missing_rate` fraction of rows, every feature from
    /// a random starting column to the last column is masked.
    Monotone { missing_rate: f64 },
}
77
/// Dataset generator for benchmarking
///
/// Configure with the builder-style setters, then call one of the `generate_*` methods to
/// obtain a complete (no missing values) `n_samples x n_features` matrix.
pub struct BenchmarkDatasetGenerator {
    /// Number of rows to generate.
    n_samples: usize,
    /// Number of columns to generate.
    n_features: usize,
    /// Half-width of the uniform noise added to derived features (default 0.1).
    noise_level: f64,
    /// Weight of the base feature in `generate_correlated_data` (default 0.5).
    correlation_strength: f64,
    /// Requested seed. NOTE(review): the `generate_*` methods construct
    /// `Random::default()` regardless of this value, so output is not reproducible yet.
    random_state: Option<u64>,
}
86
87impl BenchmarkDatasetGenerator {
88    /// Create a new dataset generator
89    pub fn new(n_samples: usize, n_features: usize) -> Self {
90        Self {
91            n_samples,
92            n_features,
93            noise_level: 0.1,
94            correlation_strength: 0.5,
95            random_state: None,
96        }
97    }
98
99    /// Set noise level
100    pub fn noise_level(mut self, noise_level: f64) -> Self {
101        self.noise_level = noise_level;
102        self
103    }
104
105    /// Set correlation strength between features
106    pub fn correlation_strength(mut self, correlation_strength: f64) -> Self {
107        self.correlation_strength = correlation_strength;
108        self
109    }
110
111    /// Set random state
112    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
113        self.random_state = random_state;
114        self
115    }
116
117    /// Generate a correlated multivariate dataset
118    pub fn generate_correlated_data(&self) -> SklResult<Array2<f64>> {
119        let mut rng = if let Some(_seed) = self.random_state {
120            Random::default()
121        } else {
122            Random::default()
123        };
124
125        let mut data = Array2::zeros((self.n_samples, self.n_features));
126
127        // Generate base features
128        for i in 0..self.n_samples {
129            data[[i, 0]] = rng.gen_range(-3.0..3.0);
130        }
131
132        // Generate correlated features
133        for j in 1..self.n_features {
134            let correlation = self.correlation_strength;
135            for i in 0..self.n_samples {
136                let base_value = data[[i, 0]];
137                let noise = rng.gen_range(-self.noise_level..self.noise_level);
138                data[[i, j]] = correlation * base_value
139                    + (1.0 - correlation) * rng.gen_range(-2.0..2.0)
140                    + noise;
141            }
142        }
143
144        Ok(data)
145    }
146
147    /// Generate linear relationship dataset
148    pub fn generate_linear_data(&self) -> SklResult<Array2<f64>> {
149        let mut rng = if let Some(_seed) = self.random_state {
150            Random::default()
151        } else {
152            Random::default()
153        };
154
155        let mut data = Array2::zeros((self.n_samples, self.n_features));
156
157        // Generate features with linear relationships
158        for i in 0..self.n_samples {
159            // Base feature
160            data[[i, 0]] = rng.gen_range(-5.0..5.0);
161
162            // Linearly related features
163            for j in 1..self.n_features {
164                let coef = (j as f64) * 0.5;
165                let noise = rng.gen_range(-self.noise_level..self.noise_level);
166                data[[i, j]] = coef * data[[i, 0]] + noise;
167            }
168        }
169
170        Ok(data)
171    }
172
173    /// Generate non-linear relationship dataset
174    pub fn generate_nonlinear_data(&self) -> SklResult<Array2<f64>> {
175        let mut rng = if let Some(_seed) = self.random_state {
176            Random::default()
177        } else {
178            Random::default()
179        };
180
181        let mut data = Array2::zeros((self.n_samples, self.n_features));
182
183        for i in 0..self.n_samples {
184            let x: f64 = rng.gen_range(-2.0..2.0);
185            data[[i, 0]] = x;
186
187            // Non-linear relationships
188            data[[i, 1]] = x.powi(2) + rng.gen_range(-self.noise_level..self.noise_level);
189
190            if self.n_features > 2 {
191                data[[i, 2]] = (x * 1.5).sin() + rng.gen_range(-self.noise_level..self.noise_level);
192            }
193
194            if self.n_features > 3 {
195                data[[i, 3]] = (x.powi(2) + x).exp() / 10.0
196                    + rng.gen_range(-self.noise_level..self.noise_level);
197            }
198
199            // Additional features with mixed relationships
200            for j in 4..self.n_features {
201                let noise = rng.gen_range(-self.noise_level..self.noise_level);
202                data[[i, j]] = (x + (j as f64) * 0.2).cos() + noise;
203            }
204        }
205
206        Ok(data)
207    }
208}
209
/// Missing pattern generator
///
/// Masks cells of a complete matrix according to a [`MissingPattern`].
pub struct MissingPatternGenerator {
    /// Requested seed. NOTE(review): `introduce_missing` builds `Random::default()`
    /// regardless of this value, so masks are not reproducible yet.
    random_state: Option<u64>,
}
214
215impl MissingPatternGenerator {
216    /// Create a new missing pattern generator
217    pub fn new() -> Self {
218        Self { random_state: None }
219    }
220
221    /// Set random state
222    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
223        self.random_state = random_state;
224        self
225    }
226
227    /// Introduce missing values according to pattern
228    pub fn introduce_missing(
229        &self,
230        data: &Array2<f64>,
231        pattern: &MissingPattern,
232    ) -> SklResult<(Array2<f64>, Array2<bool>)> {
233        let mut rng = if let Some(_seed) = self.random_state {
234            Random::default()
235        } else {
236            Random::default()
237        };
238
239        let (n_samples, n_features) = data.dim();
240        let mut data_with_missing = data.clone();
241        let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
242
243        match pattern {
244            MissingPattern::MCAR { missing_rate } => {
245                self.introduce_mcar(
246                    &mut data_with_missing,
247                    &mut missing_mask,
248                    *missing_rate,
249                    &mut rng,
250                )?;
251            }
252            MissingPattern::MAR {
253                missing_rate,
254                dependency_strength,
255            } => {
256                self.introduce_mar(
257                    data,
258                    &mut data_with_missing,
259                    &mut missing_mask,
260                    *missing_rate,
261                    *dependency_strength,
262                    &mut rng,
263                )?;
264            }
265            MissingPattern::MNAR {
266                missing_rate,
267                threshold,
268            } => {
269                self.introduce_mnar(
270                    data,
271                    &mut data_with_missing,
272                    &mut missing_mask,
273                    *missing_rate,
274                    *threshold,
275                    &mut rng,
276                )?;
277            }
278            MissingPattern::Block {
279                block_size,
280                missing_rate,
281            } => {
282                self.introduce_block(
283                    &mut data_with_missing,
284                    &mut missing_mask,
285                    *block_size,
286                    *missing_rate,
287                    &mut rng,
288                )?;
289            }
290            MissingPattern::Monotone { missing_rate } => {
291                self.introduce_monotone(
292                    &mut data_with_missing,
293                    &mut missing_mask,
294                    *missing_rate,
295                    &mut rng,
296                )?;
297            }
298        }
299
300        Ok((data_with_missing, missing_mask))
301    }
302
303    fn introduce_mcar(
304        &self,
305        data: &mut Array2<f64>,
306        missing_mask: &mut Array2<bool>,
307        missing_rate: f64,
308        rng: &mut Random,
309    ) -> SklResult<()> {
310        let total_elements = data.len();
311        let n_missing = (total_elements as f64 * missing_rate) as usize;
312
313        let mut positions: Vec<(usize, usize)> = Vec::new();
314        for i in 0..data.nrows() {
315            for j in 0..data.ncols() {
316                positions.push((i, j));
317            }
318        }
319
320        positions.shuffle(rng);
321
322        for &(i, j) in positions.iter().take(n_missing) {
323            data[[i, j]] = f64::NAN;
324            missing_mask[[i, j]] = true;
325        }
326
327        Ok(())
328    }
329
330    fn introduce_mar(
331        &self,
332        original_data: &Array2<f64>,
333        data: &mut Array2<f64>,
334        missing_mask: &mut Array2<bool>,
335        missing_rate: f64,
336        dependency_strength: f64,
337        rng: &mut Random,
338    ) -> SklResult<()> {
339        let (n_samples, n_features) = data.dim();
340
341        if n_features < 2 {
342            return self.introduce_mcar(data, missing_mask, missing_rate, rng);
343        }
344
345        // Make missingness in columns 1+ depend on column 0
346        let column_0_median = {
347            let mut sorted: Vec<f64> = original_data.column(0).iter().cloned().collect();
348            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
349            sorted[sorted.len() / 2]
350        };
351
352        for i in 0..n_samples {
353            for j in 1..n_features {
354                let base_prob = missing_rate;
355                let prob_adjustment = if original_data[[i, 0]] > column_0_median {
356                    dependency_strength
357                } else {
358                    -dependency_strength
359                };
360
361                let prob_missing = (base_prob + prob_adjustment).clamp(0.0, 1.0);
362
363                if rng.gen::<f64>() < prob_missing {
364                    data[[i, j]] = f64::NAN;
365                    missing_mask[[i, j]] = true;
366                }
367            }
368        }
369
370        Ok(())
371    }
372
373    fn introduce_mnar(
374        &self,
375        original_data: &Array2<f64>,
376        data: &mut Array2<f64>,
377        missing_mask: &mut Array2<bool>,
378        missing_rate: f64,
379        threshold: f64,
380        rng: &mut Random,
381    ) -> SklResult<()> {
382        let (n_samples, n_features) = data.dim();
383
384        for j in 0..n_features {
385            let column_values: Vec<f64> = original_data.column(j).iter().cloned().collect();
386            let column_threshold = {
387                let mut sorted = column_values.clone();
388                sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
389                sorted[(sorted.len() as f64 * threshold) as usize]
390            };
391
392            for i in 0..n_samples {
393                // Higher chance of missing if value is above threshold
394                let base_prob = missing_rate;
395                let prob_missing = if original_data[[i, j]] > column_threshold {
396                    base_prob * 2.0
397                } else {
398                    base_prob * 0.5
399                };
400
401                if rng.gen::<f64>() < prob_missing.min(1.0) {
402                    data[[i, j]] = f64::NAN;
403                    missing_mask[[i, j]] = true;
404                }
405            }
406        }
407
408        Ok(())
409    }
410
411    fn introduce_block(
412        &self,
413        data: &mut Array2<f64>,
414        missing_mask: &mut Array2<bool>,
415        block_size: usize,
416        missing_rate: f64,
417        rng: &mut Random,
418    ) -> SklResult<()> {
419        let (n_samples, n_features) = data.dim();
420        let n_blocks =
421            ((n_samples * n_features) as f64 * missing_rate / block_size as f64) as usize;
422
423        for _ in 0..n_blocks {
424            let start_i = rng.gen_range(0..n_samples);
425            let start_j = rng.gen_range(0..n_features);
426
427            let block_height = (block_size as f64).sqrt() as usize;
428            let block_width = block_size / block_height.max(1);
429
430            for di in 0..block_height {
431                for dj in 0..block_width {
432                    let i = (start_i + di) % n_samples;
433                    let j = (start_j + dj) % n_features;
434                    data[[i, j]] = f64::NAN;
435                    missing_mask[[i, j]] = true;
436                }
437            }
438        }
439
440        Ok(())
441    }
442
443    fn introduce_monotone(
444        &self,
445        data: &mut Array2<f64>,
446        missing_mask: &mut Array2<bool>,
447        missing_rate: f64,
448        rng: &mut Random,
449    ) -> SklResult<()> {
450        let (n_samples, n_features) = data.dim();
451
452        if n_features == 0 {
453            return Ok(());
454        }
455
456        let mut samples_to_affect: Vec<usize> = (0..n_samples).collect();
457        samples_to_affect.shuffle(rng);
458        let n_affected = (n_samples as f64 * missing_rate) as usize;
459
460        for &sample_idx in samples_to_affect.iter().take(n_affected) {
461            // Start missing from a random feature onwards
462            let start_feature = rng.gen_range(0..n_features);
463            for j in start_feature..n_features {
464                data[[sample_idx, j]] = f64::NAN;
465                missing_mask[[sample_idx, j]] = true;
466            }
467        }
468
469        Ok(())
470    }
471}
472
473impl Default for MissingPatternGenerator {
474    fn default() -> Self {
475        Self::new()
476    }
477}
478
/// Accuracy metrics calculator
///
/// Stateless namespace for metrics that compare imputed values against ground truth,
/// evaluated only on the cells flagged `true` in a missing-value mask.
pub struct AccuracyMetrics;
481
482impl AccuracyMetrics {
483    /// Calculate Root Mean Square Error
484    pub fn rmse(
485        true_values: &Array2<f64>,
486        imputed_values: &Array2<f64>,
487        missing_mask: &Array2<bool>,
488    ) -> f64 {
489        let mut sum_squared_diff = 0.0;
490        let mut count = 0;
491
492        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
493            if is_missing {
494                let diff = true_values[[i, j]] - imputed_values[[i, j]];
495                sum_squared_diff += diff * diff;
496                count += 1;
497            }
498        }
499
500        if count > 0 {
501            (sum_squared_diff / count as f64).sqrt()
502        } else {
503            0.0
504        }
505    }
506
507    /// Calculate Mean Absolute Error
508    pub fn mae(
509        true_values: &Array2<f64>,
510        imputed_values: &Array2<f64>,
511        missing_mask: &Array2<bool>,
512    ) -> f64 {
513        let mut sum_abs_diff = 0.0;
514        let mut count = 0;
515
516        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
517            if is_missing {
518                let diff = (true_values[[i, j]] - imputed_values[[i, j]]).abs();
519                sum_abs_diff += diff;
520                count += 1;
521            }
522        }
523
524        if count > 0 {
525            sum_abs_diff / count as f64
526        } else {
527            0.0
528        }
529    }
530
531    /// Calculate bias (mean error)
532    pub fn bias(
533        true_values: &Array2<f64>,
534        imputed_values: &Array2<f64>,
535        missing_mask: &Array2<bool>,
536    ) -> f64 {
537        let mut sum_diff = 0.0;
538        let mut count = 0;
539
540        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
541            if is_missing {
542                let diff = imputed_values[[i, j]] - true_values[[i, j]];
543                sum_diff += diff;
544                count += 1;
545            }
546        }
547
548        if count > 0 {
549            sum_diff / count as f64
550        } else {
551            0.0
552        }
553    }
554
555    /// Calculate R-squared coefficient
556    pub fn r_squared(
557        true_values: &Array2<f64>,
558        imputed_values: &Array2<f64>,
559        missing_mask: &Array2<bool>,
560    ) -> f64 {
561        let mut missing_true_values = Vec::new();
562        let mut missing_imputed_values = Vec::new();
563
564        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
565            if is_missing {
566                missing_true_values.push(true_values[[i, j]]);
567                missing_imputed_values.push(imputed_values[[i, j]]);
568            }
569        }
570
571        if missing_true_values.is_empty() {
572            return 1.0;
573        }
574
575        let true_mean = missing_true_values.iter().sum::<f64>() / missing_true_values.len() as f64;
576
577        let ss_tot: f64 = missing_true_values
578            .iter()
579            .map(|&x| (x - true_mean).powi(2))
580            .sum();
581
582        let ss_res: f64 = missing_true_values
583            .iter()
584            .zip(missing_imputed_values.iter())
585            .map(|(&true_val, &imputed_val)| (true_val - imputed_val).powi(2))
586            .sum();
587
588        if ss_tot == 0.0 {
589            1.0
590        } else {
591            1.0 - (ss_res / ss_tot)
592        }
593    }
594}
595
/// Performance benchmarking suite
///
/// Holds the datasets and missing-data patterns to cross.
/// [`BenchmarkSuite::benchmark_imputer`] runs an imputer once per (dataset, pattern)
/// combination.
pub struct BenchmarkSuite {
    /// Named complete datasets used as ground truth.
    datasets: Vec<(String, Array2<f64>)>,
    /// Named patterns used to mask each dataset.
    missing_patterns: Vec<(String, MissingPattern)>,
    /// Seed forwarded to the [`MissingPatternGenerator`] used by `benchmark_imputer`.
    random_state: Option<u64>,
}
602
603impl BenchmarkSuite {
604    /// Create a new benchmark suite
605    pub fn new() -> Self {
606        Self {
607            datasets: Vec::new(),
608            missing_patterns: Vec::new(),
609            random_state: None,
610        }
611    }
612
613    /// Add a dataset to the benchmark suite
614    pub fn add_dataset(mut self, name: String, data: Array2<f64>) -> Self {
615        self.datasets.push((name, data));
616        self
617    }
618
619    /// Add a missing pattern to test
620    pub fn add_missing_pattern(mut self, name: String, pattern: MissingPattern) -> Self {
621        self.missing_patterns.push((name, pattern));
622        self
623    }
624
625    /// Set random state
626    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
627        self.random_state = random_state;
628        self
629    }
630
631    /// Add standard benchmark datasets
632    pub fn add_standard_datasets(mut self) -> Self {
633        // Linear relationship dataset
634        let linear_gen = BenchmarkDatasetGenerator::new(100, 4)
635            .correlation_strength(0.7)
636            .noise_level(0.1)
637            .random_state(Some(42));
638
639        if let Ok(linear_data) = linear_gen.generate_linear_data() {
640            self.datasets
641                .push(("linear_100x4".to_string(), linear_data));
642        }
643
644        // Non-linear relationship dataset
645        let nonlinear_gen = BenchmarkDatasetGenerator::new(80, 3)
646            .noise_level(0.2)
647            .random_state(Some(123));
648
649        if let Ok(nonlinear_data) = nonlinear_gen.generate_nonlinear_data() {
650            self.datasets
651                .push(("nonlinear_80x3".to_string(), nonlinear_data));
652        }
653
654        // Correlated dataset
655        let correlated_gen = BenchmarkDatasetGenerator::new(120, 5)
656            .correlation_strength(0.8)
657            .noise_level(0.05)
658            .random_state(Some(456));
659
660        if let Ok(correlated_data) = correlated_gen.generate_correlated_data() {
661            self.datasets
662                .push(("correlated_120x5".to_string(), correlated_data));
663        }
664
665        self
666    }
667
668    /// Add standard missing patterns
669    pub fn add_standard_patterns(mut self) -> Self {
670        self.missing_patterns.push((
671            "MCAR_15%".to_string(),
672            MissingPattern::MCAR { missing_rate: 0.15 },
673        ));
674
675        self.missing_patterns.push((
676            "MAR_20%".to_string(),
677            MissingPattern::MAR {
678                missing_rate: 0.20,
679                dependency_strength: 0.3,
680            },
681        ));
682
683        self.missing_patterns.push((
684            "MNAR_10%".to_string(),
685            MissingPattern::MNAR {
686                missing_rate: 0.10,
687                threshold: 0.7,
688            },
689        ));
690
691        self.missing_patterns.push((
692            "Block_12%".to_string(),
693            MissingPattern::Block {
694                block_size: 4,
695                missing_rate: 0.12,
696            },
697        ));
698
699        self
700    }
701
702    /// Run benchmark on a specific imputer
703    pub fn benchmark_imputer<I, T>(
704        &self,
705        imputer: I,
706        imputer_name: &str,
707    ) -> SklResult<Vec<ImputationBenchmark>>
708    where
709        I: Clone,
710        for<'a> I: Fit<ArrayView2<'a, Float>, (), Fitted = T>,
711        for<'a> T: Transform<ArrayView2<'a, Float>, Array2<Float>>,
712    {
713        let mut results = Vec::new();
714        let pattern_generator = MissingPatternGenerator::new().random_state(self.random_state);
715
716        for (dataset_name, true_data) in &self.datasets {
717            for (pattern_name, pattern) in &self.missing_patterns {
718                let (data_with_missing, missing_mask) =
719                    pattern_generator.introduce_missing(true_data, pattern)?;
720
721                let data_float = data_with_missing.mapv(|x| x as Float);
722
723                // Measure execution time
724                let start_time = Instant::now();
725
726                let fitted = imputer.clone().fit(&data_float.view(), &())?;
727                let imputed_data = fitted.transform(&data_float.view())?;
728
729                let execution_time = start_time.elapsed();
730
731                // Calculate accuracy metrics
732                let imputed_f64 = imputed_data.mapv(|x| x);
733                let rmse = AccuracyMetrics::rmse(true_data, &imputed_f64, &missing_mask);
734                let mae = AccuracyMetrics::mae(true_data, &imputed_f64, &missing_mask);
735
736                let missing_rate =
737                    missing_mask.iter().filter(|&&x| x).count() as f64 / missing_mask.len() as f64;
738
739                results.push(ImputationBenchmark {
740                    method_name: imputer_name.to_string(),
741                    dataset_name: dataset_name.clone(),
742                    missing_rate,
743                    missing_pattern: pattern_name.clone(),
744                    rmse,
745                    mae,
746                    execution_time,
747                    memory_usage: None,
748                    convergence_iterations: None,
749                });
750            }
751        }
752
753        Ok(results)
754    }
755
756    /// Compare multiple imputers
757    pub fn compare_imputers(&self, benchmarks: Vec<ImputationBenchmark>) -> ImputationComparison {
758        if benchmarks.is_empty() {
759            return ImputationComparison {
760                benchmarks: Vec::new(),
761                best_rmse_method: String::new(),
762                best_mae_method: String::new(),
763                fastest_method: String::new(),
764                accuracy_rankings: HashMap::new(),
765                speed_rankings: HashMap::new(),
766            };
767        }
768
769        // Find best performing methods
770        let best_rmse = benchmarks
771            .iter()
772            .min_by(|a, b| a.rmse.partial_cmp(&b.rmse).unwrap());
773        let best_mae = benchmarks
774            .iter()
775            .min_by(|a, b| a.mae.partial_cmp(&b.mae).unwrap());
776        let fastest = benchmarks.iter().min_by_key(|b| b.execution_time);
777
778        let best_rmse_method = best_rmse.map(|b| b.method_name.clone()).unwrap_or_default();
779        let best_mae_method = best_mae.map(|b| b.method_name.clone()).unwrap_or_default();
780        let fastest_method = fastest.map(|b| b.method_name.clone()).unwrap_or_default();
781
782        // Calculate rankings
783        let mut accuracy_rankings = HashMap::new();
784        let mut speed_rankings = HashMap::new();
785
786        // Group benchmarks by method
787        let mut method_avg_rmse: HashMap<String, f64> = HashMap::new();
788        let mut method_avg_time: HashMap<String, Duration> = HashMap::new();
789        let mut method_counts: HashMap<String, usize> = HashMap::new();
790
791        for benchmark in &benchmarks {
792            let count = method_counts
793                .entry(benchmark.method_name.clone())
794                .or_insert(0);
795            *count += 1;
796
797            let avg_rmse = method_avg_rmse
798                .entry(benchmark.method_name.clone())
799                .or_insert(0.0);
800            *avg_rmse += benchmark.rmse;
801
802            let avg_time = method_avg_time
803                .entry(benchmark.method_name.clone())
804                .or_insert(Duration::ZERO);
805            *avg_time += benchmark.execution_time;
806        }
807
808        // Calculate averages
809        for (method, count) in &method_counts {
810            if let Some(total_rmse) = method_avg_rmse.get_mut(method) {
811                *total_rmse /= *count as f64;
812            }
813            if let Some(total_time) = method_avg_time.get_mut(method) {
814                *total_time /= *count as u32;
815            }
816        }
817
818        // Create sorted rankings
819        let mut rmse_pairs: Vec<_> = method_avg_rmse.into_iter().collect();
820        rmse_pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
821        for (rank, (method, _)) in rmse_pairs.into_iter().enumerate() {
822            accuracy_rankings.insert(method, rank + 1);
823        }
824
825        let mut time_pairs: Vec<_> = method_avg_time.into_iter().collect();
826        time_pairs.sort_by_key(|a| a.1);
827        for (rank, (method, _)) in time_pairs.into_iter().enumerate() {
828            speed_rankings.insert(method, rank + 1);
829        }
830
831        ImputationComparison {
832            benchmarks,
833            best_rmse_method,
834            best_mae_method,
835            fastest_method,
836            accuracy_rankings,
837            speed_rankings,
838        }
839    }
840
841    /// Generate a comprehensive benchmark report
842    pub fn generate_report(&self, comparison: &ImputationComparison) -> String {
843        let mut report = String::new();
844
845        report.push_str("# Imputation Methods Benchmark Report\n\n");
846
847        report.push_str("## Summary\n");
848        report.push_str(&format!("- Best RMSE: {}\n", comparison.best_rmse_method));
849        report.push_str(&format!("- Best MAE: {}\n", comparison.best_mae_method));
850        report.push_str(&format!("- Fastest: {}\n\n", comparison.fastest_method));
851
852        report.push_str("## Accuracy Rankings\n");
853        let mut accuracy_pairs: Vec<_> = comparison.accuracy_rankings.iter().collect();
854        accuracy_pairs.sort_by_key(|&(_, rank)| rank);
855        for (method, rank) in accuracy_pairs {
856            report.push_str(&format!("{}. {}\n", rank, method));
857        }
858
859        report.push_str("\n## Speed Rankings\n");
860        let mut speed_pairs: Vec<_> = comparison.speed_rankings.iter().collect();
861        speed_pairs.sort_by_key(|&(_, rank)| rank);
862        for (method, rank) in speed_pairs {
863            report.push_str(&format!("{}. {}\n", rank, method));
864        }
865
866        report.push_str("\n## Detailed Results\n");
867        for benchmark in &comparison.benchmarks {
868            report.push_str(&format!(
869                "- {}: {} on {} ({}): RMSE={:.4}, MAE={:.4}, Time={:.2}ms\n",
870                benchmark.method_name,
871                benchmark.missing_pattern,
872                benchmark.dataset_name,
873                (benchmark.missing_rate * 100.0).round(),
874                benchmark.rmse,
875                benchmark.mae,
876                benchmark.execution_time.as_secs_f64() * 1000.0
877            ));
878        }
879
880        report
881    }
882}
883
884impl Default for BenchmarkSuite {
885    fn default() -> Self {
886        Self::new()
887    }
888}
889
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{KNNImputer, SimpleImputer};

    /// Counts how many entries of a missing-value mask are flagged `true`.
    fn count_missing<'a>(flags: impl Iterator<Item = &'a bool>) -> usize {
        flags.filter(|flag| **flag).count()
    }

    #[test]
    fn test_dataset_generation() {
        let generator = BenchmarkDatasetGenerator::new(50, 3).random_state(Some(42));
        let expected: &[usize] = &[50, 3];

        // Each synthetic generator must honour the requested (n_samples, n_features) shape.
        assert_eq!(generator.generate_linear_data().unwrap().shape(), expected);
        assert_eq!(generator.generate_nonlinear_data().unwrap().shape(), expected);
        assert_eq!(generator.generate_correlated_data().unwrap().shape(), expected);
    }

    #[test]
    fn test_missing_pattern_generation() {
        let data = Array2::from_shape_fn((20, 3), |(row, col)| (row + col) as f64);
        let generator = MissingPatternGenerator::new().random_state(Some(123));

        // MCAR: some entries are masked, but never the whole matrix.
        let mcar = MissingPattern::MCAR { missing_rate: 0.2 };
        let (_mcar_data, mcar_mask) = generator.introduce_missing(&data, &mcar).unwrap();
        let mcar_missing = count_missing(mcar_mask.iter());
        assert!(mcar_missing > 0);
        assert!(mcar_missing < data.len());

        // MAR: missingness that depends on observed values still masks something.
        let mar = MissingPattern::MAR {
            missing_rate: 0.15,
            dependency_strength: 0.3,
        };
        let (_mar_data, mar_mask) = generator.introduce_missing(&data, &mar).unwrap();
        assert!(count_missing(mar_mask.iter()) > 0);

        // Block pattern (block_size = 4) still masks something.
        let block = MissingPattern::Block {
            block_size: 4,
            missing_rate: 0.1,
        };
        let (_block_data, block_mask) = generator.introduce_missing(&data, &block).unwrap();
        assert!(count_missing(block_mask.iter()) > 0);
    }

    #[test]
    fn test_accuracy_metrics() {
        let truth = Array2::from_shape_fn((10, 2), |(row, col)| (row + col) as f64);
        let mut imputed = truth.clone();
        let mut mask = Array2::from_elem((10, 2), false);

        // Overwrite two cells with values above the truth and flag them as imputed.
        let injected: [([usize; 2], f64); 2] = [([0, 0], 10.0), ([1, 1], 20.0)];
        for &(cell, value) in injected.iter() {
            imputed[cell] = value;
            mask[cell] = true;
        }

        let rmse = AccuracyMetrics::rmse(&truth, &imputed, &mask);
        let mae = AccuracyMetrics::mae(&truth, &imputed, &mask);
        let bias = AccuracyMetrics::bias(&truth, &imputed, &mask);

        assert!(rmse > 0.0);
        assert!(mae > 0.0);
        // Both injected values overshoot the truth, so the bias must be positive.
        assert!(bias > 0.0);
    }

    #[test]
    fn test_benchmark_suite() {
        let dataset = Array2::from_shape_fn((30, 3), |(row, col)| (row + col) as f64);
        let suite = BenchmarkSuite::new()
            .add_dataset("test_data".to_string(), dataset)
            .add_missing_pattern(
                "test_mcar".to_string(),
                MissingPattern::MCAR { missing_rate: 0.1 },
            )
            .random_state(Some(42));

        // Mean-strategy SimpleImputer: one dataset x one pattern -> exactly one result.
        let mean_results = suite
            .benchmark_imputer(
                SimpleImputer::new().strategy("mean".to_string()),
                "SimpleImputer",
            )
            .unwrap();
        assert_eq!(mean_results.len(), 1);
        let first = &mean_results[0];
        assert_eq!(first.method_name, "SimpleImputer");
        assert!(first.rmse >= 0.0);
        assert!(first.mae >= 0.0);

        // KNN imputer with 3 neighbours.
        let knn_results = suite
            .benchmark_imputer(KNNImputer::new().n_neighbors(3), "KNNImputer")
            .unwrap();
        assert_eq!(knn_results.len(), 1);
        assert_eq!(knn_results[0].method_name, "KNNImputer");

        // Both methods must appear in both ranking tables.
        let comparison = suite.compare_imputers([mean_results, knn_results].concat());
        for method in ["SimpleImputer", "KNNImputer"] {
            assert!(comparison.accuracy_rankings.contains_key(method));
        }
        for method in ["SimpleImputer", "KNNImputer"] {
            assert!(comparison.speed_rankings.contains_key(method));
        }
    }

    #[test]
    fn test_standard_benchmarks() {
        let suite = BenchmarkSuite::new()
            .add_standard_datasets()
            .add_standard_patterns()
            .random_state(Some(42));

        assert!(!suite.datasets.is_empty());
        assert!(!suite.missing_patterns.is_empty());

        let results = suite
            .benchmark_imputer(
                SimpleImputer::new().strategy("mean".to_string()),
                "SimpleImputer",
            )
            .unwrap();

        // One result per (dataset, missing pattern) combination.
        let combinations = suite.datasets.len() * suite.missing_patterns.len();
        assert_eq!(results.len(), combinations);

        // Every result must carry non-negative error metrics and a measured runtime.
        for outcome in &results {
            assert!(outcome.rmse >= 0.0);
            assert!(outcome.mae >= 0.0);
            assert!(outcome.execution_time > Duration::ZERO);
        }
    }
}