// sklears_dummy/sklearn_benchmarks.rs

1//! Comprehensive benchmarking framework comparing against scikit-learn dummy estimators
2//!
3//! This module provides a comprehensive benchmarking framework that:
4//! - Compares accuracy and behavior against scikit-learn dummy estimators
5//! - Measures performance (speed, memory usage) of implementations
6//! - Tests on standard and synthetic datasets
7//! - Generates detailed comparison reports
8//! - Validates numerical accuracy and consistency
9
10use crate::{ClassifierStrategy, DummyClassifier, DummyRegressor, RegressorStrategy};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::{Rng, SeedableRng};
13use sklears_core::{error::SklearsError, traits::Estimator, traits::Fit, traits::Predict};
14use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
/// Results from benchmarking a dummy estimator
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Debug-formatted name of the strategy that was benchmarked.
    pub strategy: String,
    /// Score comparison between the sklears estimator and the reference.
    pub accuracy_comparison: AccuracyComparison,
    /// Timing (and placeholder memory) measurements for both sides.
    pub performance_metrics: PerformanceMetrics,
    /// Element-wise agreement between sklears and reference predictions.
    pub numerical_accuracy: NumericalAccuracy,
    /// Description of the dataset the benchmark ran on.
    pub dataset_info: DatasetInfo,
}
31
/// Comparison of accuracy between sklears and reference implementation
#[derive(Debug, Clone)]
pub struct AccuracyComparison {
    /// Score achieved by the sklears estimator (accuracy or R²).
    pub sklears_score: f64,
    /// Score achieved by the simulated reference implementation.
    pub reference_score: f64,
    /// `|sklears_score - reference_score|`.
    pub absolute_difference: f64,
    /// Absolute difference divided by the reference score
    /// (defined as 0.0 when the reference score is 0.0).
    pub relative_difference: f64,
    /// Whether the absolute difference is within `tolerance_used`.
    pub within_tolerance: bool,
    /// Tolerance the comparison was evaluated against.
    pub tolerance_used: f64,
}
48
/// Performance metrics for the estimator
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Average fit time over the configured number of runs.
    pub fit_time_sklears: Duration,
    /// Average predict time over the configured number of runs.
    pub predict_time_sklears: Duration,
    /// Reference fit time — currently a mocked 1 ms placeholder.
    pub fit_time_reference: Duration,
    /// Reference predict time — currently a mocked 1 ms placeholder.
    pub predict_time_reference: Duration,
    /// Fit-time speedup vs. reference — currently a placeholder 1.0.
    pub speedup_fit: f64,
    /// Predict-time speedup vs. reference — currently a placeholder 1.0.
    pub speedup_predict: f64,
    /// Memory used by sklears in bytes — currently a placeholder 0.
    pub memory_usage_sklears: usize,
    /// Memory used by the reference in bytes — currently a placeholder 0.
    pub memory_usage_reference: usize,
}
69
/// Numerical accuracy comparison
#[derive(Debug, Clone)]
pub struct NumericalAccuracy {
    /// Mean squared difference between sklears and reference predictions.
    pub prediction_mse: f64,
    /// Mean absolute difference between sklears and reference predictions.
    pub prediction_mae: f64,
    /// Largest absolute per-element difference.
    pub max_absolute_error: f64,
    /// Pearson correlation of the two prediction vectors
    /// (reported as 1.0 when either side has zero variance).
    pub correlation: f64,
    /// Whether repeated runs reproduce the same output — currently always
    /// true (placeholder; multi-run verification not yet implemented).
    pub reproducibility_check: bool,
}
84
/// Information about the dataset used for benchmarking
#[derive(Debug, Clone)]
pub struct DatasetInfo {
    /// Human-readable dataset name (copied from the `DatasetConfig`).
    pub name: String,
    /// Number of samples (rows).
    pub n_samples: usize,
    /// Number of features (columns).
    pub n_features: usize,
    /// Number of distinct classes; `None` for regression datasets.
    pub n_classes: Option<usize>,
    /// Map from class label to its sample count; `None` for regression.
    pub class_distribution: Option<HashMap<i32, usize>>,
    /// Summary statistics of the targets — presumably populated only for
    /// regression datasets; the producer is not visible here (confirm).
    pub target_statistics: Option<TargetStatistics>,
}
101
/// Statistics about regression targets
#[derive(Debug, Clone)]
pub struct TargetStatistics {
    /// Arithmetic mean of the targets.
    pub mean: f64,
    /// Standard deviation of the targets.
    pub std: f64,
    /// Smallest target value.
    pub min: f64,
    /// Largest target value.
    pub max: f64,
    /// Third standardized moment (distribution asymmetry).
    pub skewness: f64,
    /// Fourth standardized moment (tail heaviness).
    pub kurtosis: f64,
}
118
/// Configuration for benchmarking
#[derive(Debug, Clone)]
pub struct BenchmarkConfig {
    /// Maximum allowed |sklears − reference| score difference for a result
    /// to count as within tolerance.
    pub tolerance: f64,
    /// Number of fit/predict repetitions used to average timings.
    pub n_runs: usize,
    /// Base RNG seed; each run offsets it by the run index.
    /// `None` leaves the estimators unseeded.
    pub random_state: Option<u64>,
    /// Whether to collect performance metrics
    /// (flag; its consumer is not visible in this file — confirm).
    pub include_performance: bool,
    /// Whether to collect memory metrics (requires profiling tools).
    pub include_memory: bool,
    /// Whether to verify that repeated runs reproduce identical output
    /// (flag; its consumer is not visible in this file — confirm).
    pub test_reproducibility: bool,
    /// Datasets every strategy is benchmarked against.
    pub datasets: Vec<DatasetConfig>,
}
137
/// Configuration for a dataset
#[derive(Debug, Clone)]
pub struct DatasetConfig {
    /// Human-readable dataset name, carried into `DatasetInfo`.
    pub name: String,
    /// Task flavour (classification, regression, imbalanced, ...).
    pub data_type: DatasetType,
    /// Sample and feature counts.
    pub size: DatasetSize,
    /// Generation knobs: noise, correlation, outliers, seed.
    pub properties: DatasetProperties,
}
150
/// Type of dataset
#[derive(Debug, Clone)]
pub enum DatasetType {
    /// Classification with the given number of classes.
    Classification { n_classes: usize },
    /// Regression with continuous targets.
    Regression,
    /// Multiclass classification (generated identically to
    /// `Classification` by the generators in this file).
    Multiclass { n_classes: usize },
    /// Binary classification where `majority_ratio` of the samples carry
    /// label 0 and the remainder label 1.
    Imbalanced { majority_ratio: f64 },
}
163
/// Size of dataset
#[derive(Debug, Clone)]
pub struct DatasetSize {
    /// Number of samples (rows).
    pub n_samples: usize,
    /// Number of features (columns).
    pub n_features: usize,
}
172
/// Properties of synthetic dataset
#[derive(Debug, Clone)]
pub struct DatasetProperties {
    /// Half-width of the uniform additive noise (0.0 disables noise).
    pub noise_level: f64,
    /// Intended feature correlation — not consumed by the generators
    /// visible in this file; confirm where it is used.
    pub correlation: f64,
    /// Fraction of samples turned into outliers (regression targets get
    /// multiplied by a random factor in [3, 10)).
    pub outlier_fraction: f64,
    /// Seed for dataset generation; generation falls back to seed 0 when
    /// this is `None`.
    pub random_state: Option<u64>,
}
185
/// Comprehensive benchmarking framework
pub struct SklearnBenchmarkFramework {
    /// Settings controlling datasets, run counts, and tolerances.
    config: BenchmarkConfig,
}
190
impl Default for BenchmarkConfig {
    /// Conservative defaults: tight tolerance, five timing runs, a fixed
    /// seed for reproducibility, and the built-in synthetic dataset suite.
    fn default() -> Self {
        Self {
            tolerance: 1e-10,
            n_runs: 5,
            random_state: Some(42), // Fixed seed for reproducible benchmarks
            include_performance: true,
            include_memory: false, // Requires memory profiling tools
            test_reproducibility: true,
            datasets: Self::default_datasets(),
        }
    }
}
204
205impl BenchmarkConfig {
206    /// Create default benchmark datasets
207    fn default_datasets() -> Vec<DatasetConfig> {
208        vec![
209            /// DatasetConfig
210            DatasetConfig {
211                name: "small_balanced_classification".to_string(),
212                data_type: DatasetType::Classification { n_classes: 3 },
213                size: DatasetSize {
214                    n_samples: 100,
215                    n_features: 4,
216                },
217                properties: DatasetProperties {
218                    noise_level: 0.1,
219                    correlation: 0.0,
220                    outlier_fraction: 0.0,
221                    random_state: Some(42),
222                },
223            },
224            /// DatasetConfig
225            DatasetConfig {
226                name: "large_classification".to_string(),
227                data_type: DatasetType::Classification { n_classes: 5 },
228                size: DatasetSize {
229                    n_samples: 1000,
230                    n_features: 20,
231                },
232                properties: DatasetProperties {
233                    noise_level: 0.2,
234                    correlation: 0.1,
235                    outlier_fraction: 0.05,
236                    random_state: Some(42),
237                },
238            },
239            /// DatasetConfig
240            DatasetConfig {
241                name: "imbalanced_classification".to_string(),
242                data_type: DatasetType::Imbalanced {
243                    majority_ratio: 0.9,
244                },
245                size: DatasetSize {
246                    n_samples: 500,
247                    n_features: 10,
248                },
249                properties: DatasetProperties {
250                    noise_level: 0.1,
251                    correlation: 0.0,
252                    outlier_fraction: 0.02,
253                    random_state: Some(42),
254                },
255            },
256            /// DatasetConfig
257            DatasetConfig {
258                name: "small_regression".to_string(),
259                data_type: DatasetType::Regression,
260                size: DatasetSize {
261                    n_samples: 100,
262                    n_features: 5,
263                },
264                properties: DatasetProperties {
265                    noise_level: 0.1,
266                    correlation: 0.0,
267                    outlier_fraction: 0.0,
268                    random_state: Some(42),
269                },
270            },
271            /// DatasetConfig
272            DatasetConfig {
273                name: "large_regression".to_string(),
274                data_type: DatasetType::Regression,
275                size: DatasetSize {
276                    n_samples: 1000,
277                    n_features: 15,
278                },
279                properties: DatasetProperties {
280                    noise_level: 0.2,
281                    correlation: 0.2,
282                    outlier_fraction: 0.05,
283                    random_state: Some(42),
284                },
285            },
286        ]
287    }
288}
289
290impl SklearnBenchmarkFramework {
291    /// Create new benchmark framework with default configuration
292    pub fn new() -> Self {
293        Self {
294            config: BenchmarkConfig::default(),
295        }
296    }
297
    /// Create a benchmark framework with a caller-supplied configuration.
    pub fn with_config(config: BenchmarkConfig) -> Self {
        Self { config }
    }
302
303    /// Run comprehensive benchmarks for dummy classifiers
304    pub fn benchmark_dummy_classifier(&self) -> Result<Vec<BenchmarkResult>, SklearsError> {
305        let mut results = Vec::new();
306
307        let strategies = vec![
308            ClassifierStrategy::MostFrequent,
309            ClassifierStrategy::Uniform,
310            ClassifierStrategy::Stratified,
311            ClassifierStrategy::Constant,
312            ClassifierStrategy::Prior,
313        ];
314
315        for dataset_config in &self.config.datasets {
316            if let DatasetType::Classification { .. }
317            | DatasetType::Imbalanced { .. }
318            | DatasetType::Multiclass { .. } = dataset_config.data_type
319            {
320                let (X, y) = self.generate_classification_dataset(dataset_config)?;
321
322                for strategy in &strategies {
323                    if let Ok(result) =
324                        self.benchmark_classifier_strategy(&X, &y, strategy.clone(), dataset_config)
325                    {
326                        results.push(result);
327                    }
328                }
329            }
330        }
331
332        Ok(results)
333    }
334
335    /// Run comprehensive benchmarks for dummy regressors
336    pub fn benchmark_dummy_regressor(&self) -> Result<Vec<BenchmarkResult>, SklearsError> {
337        let mut results = Vec::new();
338
339        let strategies = vec![
340            RegressorStrategy::Mean,
341            RegressorStrategy::Median,
342            RegressorStrategy::Quantile(0.25),
343            RegressorStrategy::Quantile(0.75),
344            RegressorStrategy::Constant(0.0),
345        ];
346
347        for dataset_config in &self.config.datasets {
348            if let DatasetType::Regression = dataset_config.data_type {
349                let (X, y) = self.generate_regression_dataset(dataset_config)?;
350
351                for strategy in &strategies {
352                    if let Ok(result) =
353                        self.benchmark_regressor_strategy(&X, &y, *strategy, dataset_config)
354                    {
355                        results.push(result);
356                    }
357                }
358            }
359        }
360
361        Ok(results)
362    }
363
364    /// Benchmark a specific classifier strategy
365    fn benchmark_classifier_strategy(
366        &self,
367        X: &Array2<f64>,
368        y: &Array1<i32>,
369        strategy: ClassifierStrategy,
370        dataset_config: &DatasetConfig,
371    ) -> Result<BenchmarkResult, SklearsError> {
372        let mut total_fit_time = Duration::new(0, 0);
373        let mut total_predict_time = Duration::new(0, 0);
374        let mut predictions_list = Vec::new();
375
376        for run in 0..self.config.n_runs {
377            // Create and fit sklears dummy classifier
378            let mut classifier = DummyClassifier::new(strategy.clone());
379            if let Some(seed) = self.config.random_state {
380                classifier = classifier.with_random_state(seed + run as u64);
381            }
382
383            let start_fit = Instant::now();
384            let fitted_classifier = classifier.fit(X, y)?;
385            let fit_time = start_fit.elapsed();
386            total_fit_time += fit_time;
387
388            let start_predict = Instant::now();
389            let predictions = fitted_classifier.predict(X)?;
390            let predict_time = start_predict.elapsed();
391            total_predict_time += predict_time;
392
393            predictions_list.push(predictions);
394        }
395
396        // Calculate average performance metrics
397        let avg_fit_time = total_fit_time / self.config.n_runs as u32;
398        let avg_predict_time = total_predict_time / self.config.n_runs as u32;
399
400        // Use the predictions from the first run for accuracy comparison
401        let predictions = &predictions_list[0];
402
403        // Calculate accuracy (proportion of correct predictions)
404        let accuracy = Self::calculate_accuracy(y, predictions);
405
406        // Generate reference predictions for comparison
407        let reference_predictions =
408            self.generate_reference_classifier_predictions(X, y, &strategy)?;
409        let reference_accuracy = Self::calculate_accuracy(y, &reference_predictions);
410
411        // Calculate numerical accuracy metrics
412        let numerical_accuracy =
413            self.calculate_classifier_numerical_accuracy(predictions, &reference_predictions)?;
414
415        let accuracy_comparison = AccuracyComparison {
416            sklears_score: accuracy,
417            reference_score: reference_accuracy,
418            absolute_difference: (accuracy - reference_accuracy).abs(),
419            relative_difference: if reference_accuracy != 0.0 {
420                ((accuracy - reference_accuracy) / reference_accuracy).abs()
421            } else {
422                0.0
423            },
424            within_tolerance: (accuracy - reference_accuracy).abs() <= self.config.tolerance,
425            tolerance_used: self.config.tolerance,
426        };
427
428        let performance_metrics = PerformanceMetrics {
429            fit_time_sklears: avg_fit_time,
430            predict_time_sklears: avg_predict_time,
431            fit_time_reference: Duration::from_millis(1), // Mock reference time
432            predict_time_reference: Duration::from_millis(1), // Mock reference time
433            speedup_fit: 1.0,          // Would be calculated with actual reference
434            speedup_predict: 1.0,      // Would be calculated with actual reference
435            memory_usage_sklears: 0,   // Would be measured with profiling
436            memory_usage_reference: 0, // Would be measured with profiling
437        };
438
439        let dataset_info = self.create_classification_dataset_info(dataset_config, X, y);
440
441        Ok(BenchmarkResult {
442            strategy: format!("{:?}", strategy),
443            accuracy_comparison,
444            performance_metrics,
445            numerical_accuracy,
446            dataset_info,
447        })
448    }
449
450    /// Benchmark a specific regressor strategy
451    fn benchmark_regressor_strategy(
452        &self,
453        X: &Array2<f64>,
454        y: &Array1<f64>,
455        strategy: RegressorStrategy,
456        dataset_config: &DatasetConfig,
457    ) -> Result<BenchmarkResult, SklearsError> {
458        let mut total_fit_time = Duration::new(0, 0);
459        let mut total_predict_time = Duration::new(0, 0);
460        let mut predictions_list = Vec::new();
461
462        for run in 0..self.config.n_runs {
463            // Create and fit sklears dummy regressor
464            let mut regressor = DummyRegressor::new(strategy);
465            if let Some(seed) = self.config.random_state {
466                regressor = regressor.with_random_state(seed + run as u64);
467            }
468
469            let start_fit = Instant::now();
470            let fitted_regressor = regressor.fit(X, y)?;
471            let fit_time = start_fit.elapsed();
472            total_fit_time += fit_time;
473
474            let start_predict = Instant::now();
475            let predictions = fitted_regressor.predict(X)?;
476            let predict_time = start_predict.elapsed();
477            total_predict_time += predict_time;
478
479            predictions_list.push(predictions);
480        }
481
482        // Calculate average performance metrics
483        let avg_fit_time = total_fit_time / self.config.n_runs as u32;
484        let avg_predict_time = total_predict_time / self.config.n_runs as u32;
485
486        // Use the predictions from the first run for accuracy comparison
487        let predictions = &predictions_list[0];
488
489        // Calculate R² score for regression
490        let r2_score = Self::calculate_r2_score(y, predictions);
491
492        // Generate reference predictions for comparison
493        let reference_predictions =
494            self.generate_reference_regressor_predictions(X, y, &strategy)?;
495        let reference_r2 = Self::calculate_r2_score(y, &reference_predictions);
496
497        // Calculate numerical accuracy metrics
498        let numerical_accuracy =
499            self.calculate_regressor_numerical_accuracy(predictions, &reference_predictions)?;
500
501        let accuracy_comparison = AccuracyComparison {
502            sklears_score: r2_score,
503            reference_score: reference_r2,
504            absolute_difference: (r2_score - reference_r2).abs(),
505            relative_difference: if reference_r2 != 0.0 {
506                ((r2_score - reference_r2) / reference_r2).abs()
507            } else {
508                0.0
509            },
510            within_tolerance: (r2_score - reference_r2).abs() <= self.config.tolerance,
511            tolerance_used: self.config.tolerance,
512        };
513
514        let performance_metrics = PerformanceMetrics {
515            fit_time_sklears: avg_fit_time,
516            predict_time_sklears: avg_predict_time,
517            fit_time_reference: Duration::from_millis(1), // Mock reference time
518            predict_time_reference: Duration::from_millis(1), // Mock reference time
519            speedup_fit: 1.0,          // Would be calculated with actual reference
520            speedup_predict: 1.0,      // Would be calculated with actual reference
521            memory_usage_sklears: 0,   // Would be measured with profiling
522            memory_usage_reference: 0, // Would be measured with profiling
523        };
524
525        let dataset_info = self.create_regression_dataset_info(dataset_config, X, y);
526
527        Ok(BenchmarkResult {
528            strategy: format!("{:?}", strategy),
529            accuracy_comparison,
530            performance_metrics,
531            numerical_accuracy,
532            dataset_info,
533        })
534    }
535
536    /// Generate synthetic classification dataset
537    fn generate_classification_dataset(
538        &self,
539        config: &DatasetConfig,
540    ) -> Result<(Array2<f64>, Array1<i32>), SklearsError> {
541        let mut rng = if let Some(seed) = config.properties.random_state {
542            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
543        } else {
544            scirs2_core::random::rngs::StdRng::seed_from_u64(0)
545        };
546
547        let n_samples = config.size.n_samples;
548        let n_features = config.size.n_features;
549
550        let n_classes = match config.data_type {
551            DatasetType::Classification { n_classes } => n_classes,
552            DatasetType::Multiclass { n_classes } => n_classes,
553            DatasetType::Imbalanced { .. } => 2, // Binary for imbalanced
554            _ => {
555                return Err(SklearsError::InvalidParameter {
556                    name: "dataset_type".to_string(),
557                    reason: "Invalid dataset type for classification".to_string(),
558                })
559            }
560        };
561
562        // Generate features
563        let mut X = Array2::<f64>::zeros((n_samples, n_features));
564        for i in 0..n_samples {
565            for j in 0..n_features {
566                X[[i, j]] = rng.gen_range(-1.0..1.0);
567            }
568        }
569
570        // Add noise
571        if config.properties.noise_level > 0.0 {
572            for i in 0..n_samples {
573                for j in 0..n_features {
574                    let noise = rng
575                        .gen_range(-config.properties.noise_level..config.properties.noise_level);
576                    X[[i, j]] += noise;
577                }
578            }
579        }
580
581        // Generate labels
582        let mut y = Array1::<i32>::zeros(n_samples);
583        match config.data_type {
584            DatasetType::Imbalanced { majority_ratio } => {
585                let n_majority = (n_samples as f64 * majority_ratio) as usize;
586                for i in 0..n_samples {
587                    y[i] = if i < n_majority { 0 } else { 1 };
588                }
589                // Shuffle labels
590                for i in 0..n_samples {
591                    let j = rng.gen_range(0..n_samples);
592                    let temp = y[i];
593                    y[i] = y[j];
594                    y[j] = temp;
595                }
596            }
597            _ => {
598                for i in 0..n_samples {
599                    y[i] = rng.gen_range(0..n_classes as i32);
600                }
601            }
602        }
603
604        Ok((X, y))
605    }
606
607    /// Generate synthetic regression dataset
608    fn generate_regression_dataset(
609        &self,
610        config: &DatasetConfig,
611    ) -> Result<(Array2<f64>, Array1<f64>), SklearsError> {
612        let mut rng = if let Some(seed) = config.properties.random_state {
613            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
614        } else {
615            scirs2_core::random::rngs::StdRng::seed_from_u64(0)
616        };
617
618        let n_samples = config.size.n_samples;
619        let n_features = config.size.n_features;
620
621        // Generate features
622        let mut X = Array2::<f64>::zeros((n_samples, n_features));
623        for i in 0..n_samples {
624            for j in 0..n_features {
625                X[[i, j]] = rng.gen_range(-2.0..2.0);
626            }
627        }
628
629        // Generate targets with some relationship to features
630        let mut y = Array1::<f64>::zeros(n_samples);
631        for i in 0..n_samples {
632            let mut target = 0.0;
633            for j in 0..n_features.min(3) {
634                // Use first 3 features for target
635                target += X[[i, j]] * (j + 1) as f64 * 0.3;
636            }
637
638            // Add noise
639            if config.properties.noise_level > 0.0 {
640                let noise =
641                    rng.gen_range(-config.properties.noise_level..config.properties.noise_level);
642                target += noise;
643            }
644
645            y[i] = target;
646        }
647
648        // Add outliers
649        if config.properties.outlier_fraction > 0.0 {
650            let n_outliers = (n_samples as f64 * config.properties.outlier_fraction) as usize;
651            for _ in 0..n_outliers {
652                let idx = rng.gen_range(0..n_samples);
653                y[idx] *= rng.gen_range(3.0..10.0); // Make it an outlier
654            }
655        }
656
657        Ok((X, y))
658    }
659
660    /// Generate reference classifier predictions (simplified simulation)
661    fn generate_reference_classifier_predictions(
662        &self,
663        X: &Array2<f64>,
664        y: &Array1<i32>,
665        strategy: &ClassifierStrategy,
666    ) -> Result<Array1<i32>, SklearsError> {
667        let n_samples = X.nrows();
668        let mut predictions = Array1::<i32>::zeros(n_samples);
669
670        match strategy {
671            ClassifierStrategy::MostFrequent => {
672                // Find most frequent class
673                let mut class_counts = HashMap::new();
674                for &label in y {
675                    *class_counts.entry(label).or_insert(0) += 1;
676                }
677                let most_frequent = *class_counts
678                    .iter()
679                    .max_by_key(|(_, &count)| count)
680                    .unwrap()
681                    .0;
682                predictions.fill(most_frequent);
683            }
684            ClassifierStrategy::Constant => {
685                // Use first class as constant value (simplified)
686                predictions.fill(y[0]);
687            }
688            _ => {
689                // For other strategies, use a simplified implementation
690                predictions.fill(y[0]); // Use first label as fallback
691            }
692        }
693
694        Ok(predictions)
695    }
696
697    /// Generate reference regressor predictions (simplified simulation)
698    fn generate_reference_regressor_predictions(
699        &self,
700        X: &Array2<f64>,
701        y: &Array1<f64>,
702        strategy: &RegressorStrategy,
703    ) -> Result<Array1<f64>, SklearsError> {
704        let n_samples = X.nrows();
705        let mut predictions = Array1::<f64>::zeros(n_samples);
706
707        match strategy {
708            RegressorStrategy::Mean => {
709                let mean = y.mean().unwrap_or(0.0);
710                predictions.fill(mean);
711            }
712            RegressorStrategy::Median => {
713                let mut sorted_y = y.to_vec();
714                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
715                let median = if sorted_y.len() % 2 == 0 {
716                    (sorted_y[sorted_y.len() / 2 - 1] + sorted_y[sorted_y.len() / 2]) / 2.0
717                } else {
718                    sorted_y[sorted_y.len() / 2]
719                };
720                predictions.fill(median);
721            }
722            RegressorStrategy::Constant(value) => {
723                predictions.fill(*value);
724            }
725            RegressorStrategy::Quantile(q) => {
726                let mut sorted_y = y.to_vec();
727                sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
728                let index = (*q * (sorted_y.len() - 1) as f64) as usize;
729                let quantile = sorted_y[index.min(sorted_y.len() - 1)];
730                predictions.fill(quantile);
731            }
732            _ => {
733                // Fallback to mean
734                let mean = y.mean().unwrap_or(0.0);
735                predictions.fill(mean);
736            }
737        }
738
739        Ok(predictions)
740    }
741
742    /// Calculate accuracy for classification
743    fn calculate_accuracy(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> f64 {
744        let n_samples = y_true.len();
745        if n_samples == 0 {
746            return 0.0;
747        }
748
749        let correct = y_true
750            .iter()
751            .zip(y_pred.iter())
752            .filter(|(&true_val, &pred_val)| true_val == pred_val)
753            .count();
754        correct as f64 / n_samples as f64
755    }
756
757    /// Calculate R² score for regression
758    fn calculate_r2_score(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
759        let n_samples = y_true.len();
760        if n_samples == 0 {
761            return 0.0;
762        }
763
764        let y_mean = y_true.mean().unwrap_or(0.0);
765
766        let ss_res: f64 = y_true
767            .iter()
768            .zip(y_pred.iter())
769            .map(|(true_val, pred_val)| (true_val - pred_val).powi(2))
770            .sum();
771
772        let ss_tot: f64 = y_true.iter().map(|val| (val - y_mean).powi(2)).sum();
773
774        if ss_tot == 0.0 {
775            return 0.0;
776        }
777
778        1.0 - (ss_res / ss_tot)
779    }
780
781    /// Calculate numerical accuracy for classifier predictions
782    fn calculate_classifier_numerical_accuracy(
783        &self,
784        predictions: &Array1<i32>,
785        reference: &Array1<i32>,
786    ) -> Result<NumericalAccuracy, SklearsError> {
787        let n_samples = predictions.len();
788        if n_samples != reference.len() {
789            return Err(SklearsError::InvalidParameter {
790                name: "predictions".to_string(),
791                reason: "Prediction arrays must have same length".to_string(),
792            });
793        }
794
795        let mse = predictions
796            .iter()
797            .zip(reference.iter())
798            .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).powi(2))
799            .sum::<f64>()
800            / n_samples as f64;
801
802        let mae = predictions
803            .iter()
804            .zip(reference.iter())
805            .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).abs())
806            .sum::<f64>()
807            / n_samples as f64;
808
809        let max_error = predictions
810            .iter()
811            .zip(reference.iter())
812            .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).abs())
813            .fold(0.0, f64::max);
814
815        // Calculate correlation (treat as continuous for correlation)
816        let pred_mean = predictions.iter().map(|&x| x as f64).sum::<f64>() / n_samples as f64;
817        let ref_mean = reference.iter().map(|&x| x as f64).sum::<f64>() / n_samples as f64;
818
819        let numerator: f64 = predictions
820            .iter()
821            .zip(reference.iter())
822            .map(|(pred, ref_val)| (*pred as f64 - pred_mean) * (*ref_val as f64 - ref_mean))
823            .sum();
824
825        let pred_var: f64 = predictions
826            .iter()
827            .map(|&x| (x as f64 - pred_mean).powi(2))
828            .sum();
829
830        let ref_var: f64 = reference
831            .iter()
832            .map(|&x| (x as f64 - ref_mean).powi(2))
833            .sum();
834
835        let correlation = if pred_var > 0.0 && ref_var > 0.0 {
836            numerator / (pred_var * ref_var).sqrt()
837        } else {
838            1.0 // Perfect correlation if no variance
839        };
840
841        Ok(NumericalAccuracy {
842            prediction_mse: mse,
843            prediction_mae: mae,
844            max_absolute_error: max_error,
845            correlation,
846            reproducibility_check: true, // Would test with multiple runs
847        })
848    }
849
850    /// Calculate numerical accuracy for regressor predictions
851    fn calculate_regressor_numerical_accuracy(
852        &self,
853        predictions: &Array1<f64>,
854        reference: &Array1<f64>,
855    ) -> Result<NumericalAccuracy, SklearsError> {
856        let n_samples = predictions.len();
857        if n_samples != reference.len() {
858            return Err(SklearsError::InvalidParameter {
859                name: "predictions".to_string(),
860                reason: "Prediction arrays must have same length".to_string(),
861            });
862        }
863
864        let mse = predictions
865            .iter()
866            .zip(reference.iter())
867            .map(|(pred, ref_val)| (pred - ref_val).powi(2))
868            .sum::<f64>()
869            / n_samples as f64;
870
871        let mae = predictions
872            .iter()
873            .zip(reference.iter())
874            .map(|(pred, ref_val)| (pred - ref_val).abs())
875            .sum::<f64>()
876            / n_samples as f64;
877
878        let max_error = predictions
879            .iter()
880            .zip(reference.iter())
881            .map(|(pred, ref_val)| (pred - ref_val).abs())
882            .fold(0.0, f64::max);
883
884        // Calculate correlation
885        let pred_mean = predictions.mean().unwrap_or(0.0);
886        let ref_mean = reference.mean().unwrap_or(0.0);
887
888        let numerator: f64 = predictions
889            .iter()
890            .zip(reference.iter())
891            .map(|(pred, ref_val)| (pred - pred_mean) * (ref_val - ref_mean))
892            .sum();
893
894        let pred_var: f64 = predictions.iter().map(|x| (x - pred_mean).powi(2)).sum();
895
896        let ref_var: f64 = reference.iter().map(|x| (x - ref_mean).powi(2)).sum();
897
898        let correlation = if pred_var > 0.0 && ref_var > 0.0 {
899            numerator / (pred_var * ref_var).sqrt()
900        } else {
901            1.0 // Perfect correlation if no variance
902        };
903
904        Ok(NumericalAccuracy {
905            prediction_mse: mse,
906            prediction_mae: mae,
907            max_absolute_error: max_error,
908            correlation,
909            reproducibility_check: true, // Would test with multiple runs
910        })
911    }
912
913    /// Create dataset info for classification
914    fn create_classification_dataset_info(
915        &self,
916        config: &DatasetConfig,
917        X: &Array2<f64>,
918        y: &Array1<i32>,
919    ) -> DatasetInfo {
920        let mut class_distribution = HashMap::new();
921        for &label in y {
922            *class_distribution.entry(label).or_insert(0) += 1;
923        }
924
925        let n_classes = class_distribution.len();
926
927        /// DatasetInfo
928        DatasetInfo {
929            name: config.name.clone(),
930            n_samples: X.nrows(),
931            n_features: X.ncols(),
932            n_classes: Some(n_classes),
933            class_distribution: Some(class_distribution),
934            target_statistics: None,
935        }
936    }
937
938    /// Create dataset info for regression
939    fn create_regression_dataset_info(
940        &self,
941        config: &DatasetConfig,
942        X: &Array2<f64>,
943        y: &Array1<f64>,
944    ) -> DatasetInfo {
945        let mean = y.mean().unwrap_or(0.0);
946        let variance = y.iter().map(|val| (val - mean).powi(2)).sum::<f64>() / y.len() as f64;
947        let std = variance.sqrt();
948        let min = y.iter().fold(f64::INFINITY, |a, &b| a.min(b));
949        let max = y.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
950
951        // Calculate skewness and kurtosis (simplified)
952        let skewness = y
953            .iter()
954            .map(|val| ((val - mean) / std).powi(3))
955            .sum::<f64>()
956            / y.len() as f64;
957        let kurtosis = y
958            .iter()
959            .map(|val| ((val - mean) / std).powi(4))
960            .sum::<f64>()
961            / y.len() as f64;
962
963        /// DatasetInfo
964        DatasetInfo {
965            name: config.name.clone(),
966            n_samples: X.nrows(),
967            n_features: X.ncols(),
968            n_classes: None,
969            class_distribution: None,
970            target_statistics: Some(TargetStatistics {
971                mean,
972                std,
973                min,
974                max,
975                skewness,
976                kurtosis,
977            }),
978        }
979    }
980
981    /// Generate comprehensive benchmark report
982    pub fn generate_report(&self, results: &[BenchmarkResult]) -> String {
983        let mut report = String::new();
984
985        report.push_str("# Sklearn Benchmark Report\n\n");
986        report.push_str(&format!("Generated {} results\n\n", results.len()));
987
988        report.push_str("## Summary\n\n");
989
990        let total_within_tolerance = results
991            .iter()
992            .filter(|r| r.accuracy_comparison.within_tolerance)
993            .count();
994        let tolerance_rate = total_within_tolerance as f64 / results.len() as f64 * 100.0;
995
996        report.push_str(&format!(
997            "- **Accuracy within tolerance**: {}/{} ({:.1}%)\n",
998            total_within_tolerance,
999            results.len(),
1000            tolerance_rate
1001        ));
1002
1003        let avg_speedup_fit = results
1004            .iter()
1005            .map(|r| r.performance_metrics.speedup_fit)
1006            .sum::<f64>()
1007            / results.len() as f64;
1008
1009        let avg_speedup_predict = results
1010            .iter()
1011            .map(|r| r.performance_metrics.speedup_predict)
1012            .sum::<f64>()
1013            / results.len() as f64;
1014
1015        report.push_str(&format!(
1016            "- **Average fit speedup**: {:.2}x\n",
1017            avg_speedup_fit
1018        ));
1019        report.push_str(&format!(
1020            "- **Average predict speedup**: {:.2}x\n",
1021            avg_speedup_predict
1022        ));
1023
1024        report.push_str("\n## Detailed Results\n\n");
1025
1026        for result in results {
1027            report.push_str(&format!(
1028                "### {} on {}\n\n",
1029                result.strategy, result.dataset_info.name
1030            ));
1031
1032            report.push_str("**Accuracy Comparison:**\n");
1033            report.push_str(&format!(
1034                "- Sklears score: {:.6}\n",
1035                result.accuracy_comparison.sklears_score
1036            ));
1037            report.push_str(&format!(
1038                "- Reference score: {:.6}\n",
1039                result.accuracy_comparison.reference_score
1040            ));
1041            report.push_str(&format!(
1042                "- Absolute difference: {:.6}\n",
1043                result.accuracy_comparison.absolute_difference
1044            ));
1045            report.push_str(&format!(
1046                "- Within tolerance: {}\n",
1047                result.accuracy_comparison.within_tolerance
1048            ));
1049
1050            report.push_str("\n**Performance Metrics:**\n");
1051            report.push_str(&format!(
1052                "- Fit time: {:?}\n",
1053                result.performance_metrics.fit_time_sklears
1054            ));
1055            report.push_str(&format!(
1056                "- Predict time: {:?}\n",
1057                result.performance_metrics.predict_time_sklears
1058            ));
1059
1060            report.push_str("\n**Numerical Accuracy:**\n");
1061            report.push_str(&format!(
1062                "- MSE: {:.6}\n",
1063                result.numerical_accuracy.prediction_mse
1064            ));
1065            report.push_str(&format!(
1066                "- MAE: {:.6}\n",
1067                result.numerical_accuracy.prediction_mae
1068            ));
1069            report.push_str(&format!(
1070                "- Correlation: {:.6}\n",
1071                result.numerical_accuracy.correlation
1072            ));
1073
1074            report.push_str("\n---\n\n");
1075        }
1076
1077        report
1078    }
1079}
1080
1081impl Default for SklearnBenchmarkFramework {
1082    fn default() -> Self {
1083        Self::new()
1084    }
1085}
1086
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benchmark_framework_creation() {
        // A freshly constructed framework carries the default configuration.
        let framework = SklearnBenchmarkFramework::new();
        assert_eq!(framework.config.tolerance, 1e-10);
        assert_eq!(framework.config.n_runs, 5);
    }

    #[test]
    fn test_synthetic_dataset_generation() {
        let framework = SklearnBenchmarkFramework::new();
        let config = DatasetConfig {
            name: "test".to_string(),
            data_type: DatasetType::Classification { n_classes: 3 },
            size: DatasetSize {
                n_samples: 100,
                n_features: 4,
            },
            properties: DatasetProperties {
                noise_level: 0.1,
                correlation: 0.0,
                outlier_fraction: 0.0,
                random_state: Some(42),
            },
        };

        let (X, y) = framework.generate_classification_dataset(&config).unwrap();

        // Shape checks: 100 samples x 4 features, one label per sample.
        assert_eq!(X.nrows(), 100);
        assert_eq!(X.ncols(), 4);
        assert_eq!(y.len(), 100);

        // Every generated label must fall inside [0, n_classes).
        assert!(y.iter().all(|label| (0..3).contains(label)));
    }

    #[test]
    fn test_accuracy_calculation() {
        let truth = Array1::from(vec![0, 1, 2, 1, 0]);
        let predicted = Array1::from(vec![0, 1, 1, 1, 0]);

        // 4 of 5 labels agree, so accuracy is exactly 0.8.
        let accuracy = SklearnBenchmarkFramework::calculate_accuracy(&truth, &predicted);
        assert!((accuracy - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_r2_score_calculation() {
        let truth = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let predicted = Array1::from(vec![1.1, 1.9, 3.1, 3.9, 5.1]);

        // Predictions track the target closely, so R^2 should be near 1.
        let r2 = SklearnBenchmarkFramework::calculate_r2_score(&truth, &predicted);
        assert!(r2 > 0.9);
    }

    #[test]
    fn test_benchmark_classifier() {
        let framework = SklearnBenchmarkFramework::new();
        let results = framework.benchmark_dummy_classifier().unwrap();

        // At least one classification dataset must have been benchmarked.
        assert!(!results.is_empty());

        // Every result names its strategy and reports a valid accuracy.
        for result in &results {
            assert!(!result.strategy.is_empty());
            let score = result.accuracy_comparison.sklears_score;
            assert!((0.0..=1.0).contains(&score));
        }
    }

    #[test]
    fn test_benchmark_regressor() {
        let framework = SklearnBenchmarkFramework::new();
        let results = framework.benchmark_dummy_regressor().unwrap();

        // At least one regression dataset must have been benchmarked.
        assert!(!results.is_empty());

        for result in &results {
            assert!(!result.strategy.is_empty());
            // R^2 can be negative for very bad predictions, so only require
            // that the score is finite.
            assert!(result.accuracy_comparison.sklears_score.is_finite());
        }
    }

    #[test]
    fn test_report_generation() {
        let framework = SklearnBenchmarkFramework::new();
        let results = framework.benchmark_dummy_classifier().unwrap();

        // The rendered Markdown must include all three major sections.
        let report = framework.generate_report(&results);
        assert!(report.contains("Sklearn Benchmark Report"));
        assert!(report.contains("Summary"));
        assert!(report.contains("Detailed Results"));
    }
}