1use crate::{ClassifierStrategy, DummyClassifier, DummyRegressor, RegressorStrategy};
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::{Rng, SeedableRng};
13use sklears_core::{error::SklearsError, traits::Estimator, traits::Fit, traits::Predict};
14use std::collections::HashMap;
15use std::time::{Duration, Instant};
16
17#[derive(Debug, Clone)]
19pub struct BenchmarkResult {
20 pub strategy: String,
22 pub accuracy_comparison: AccuracyComparison,
24 pub performance_metrics: PerformanceMetrics,
26 pub numerical_accuracy: NumericalAccuracy,
28 pub dataset_info: DatasetInfo,
30}
31
/// Comparison of a score (accuracy for classifiers, R² for regressors) computed from the
/// sklears predictions against the same score computed from the reference predictions.
#[derive(Debug, Clone, PartialEq)]
pub struct AccuracyComparison {
    /// Score of the sklears estimator's predictions.
    pub sklears_score: f64,
    /// Score of the reference predictions.
    pub reference_score: f64,
    /// `|sklears_score - reference_score|`.
    pub absolute_difference: f64,
    /// `absolute_difference / |reference_score|`, or `0.0` when the reference score is zero.
    pub relative_difference: f64,
    /// Whether `absolute_difference <= tolerance_used`.
    pub within_tolerance: bool,
    /// Absolute tolerance that was applied (taken from `BenchmarkConfig::tolerance`).
    pub tolerance_used: f64,
}
48
/// Timing and memory measurements for one benchmark.
///
/// Only the sklears timings are measured, averaged over the configured number of runs.
/// NOTE(review): the benchmark code fills the reference timings, speedups and memory
/// figures with fixed placeholder values. Treat them as not yet measured.
#[derive(Debug, Clone, PartialEq)]
pub struct PerformanceMetrics {
    /// Mean wall-clock time of `fit` across runs.
    pub fit_time_sklears: Duration,
    /// Mean wall-clock time of `predict` across runs.
    pub predict_time_sklears: Duration,
    /// Reference fit time (placeholder).
    pub fit_time_reference: Duration,
    /// Reference predict time (placeholder).
    pub predict_time_reference: Duration,
    /// Fit speedup of sklears over the reference (placeholder).
    pub speedup_fit: f64,
    /// Predict speedup of sklears over the reference (placeholder).
    pub speedup_predict: f64,
    /// Memory used by sklears (not measured; unit presumably bytes — TODO confirm).
    pub memory_usage_sklears: usize,
    /// Memory used by the reference (not measured; unit presumably bytes — TODO confirm).
    pub memory_usage_reference: usize,
}
69
/// Element-wise agreement between sklears predictions and reference predictions.
///
/// Class labels are compared as `f64` values.
#[derive(Debug, Clone, PartialEq)]
pub struct NumericalAccuracy {
    /// Mean squared difference between the two prediction vectors.
    pub prediction_mse: f64,
    /// Mean absolute difference between the two prediction vectors.
    pub prediction_mae: f64,
    /// Largest absolute difference between the two prediction vectors.
    pub max_absolute_error: f64,
    /// Pearson correlation between the two prediction vectors. The value reported when
    /// either side is constant is chosen by the code that fills this struct.
    pub correlation: f64,
    /// Whether the benchmarked predictions were judged reproducible.
    pub reproducibility_check: bool,
}
84
85#[derive(Debug, Clone)]
87pub struct DatasetInfo {
88 pub name: String,
90 pub n_samples: usize,
92 pub n_features: usize,
94 pub n_classes: Option<usize>,
96 pub class_distribution: Option<HashMap<i32, usize>>,
98 pub target_statistics: Option<TargetStatistics>,
100}
101
/// Descriptive statistics of a regression target.
#[derive(Debug, Clone, PartialEq)]
pub struct TargetStatistics {
    /// Arithmetic mean.
    pub mean: f64,
    /// Population standard deviation (the variance is divided by `n`, not `n - 1`).
    pub std: f64,
    /// Smallest target value.
    pub min: f64,
    /// Largest target value.
    pub max: f64,
    /// Third standardized moment.
    pub skewness: f64,
    /// Fourth standardized moment. This is Pearson kurtosis, not excess kurtosis, so a
    /// normal distribution gives about 3.
    pub kurtosis: f64,
}
118
119#[derive(Debug, Clone)]
121pub struct BenchmarkConfig {
122 pub tolerance: f64,
124 pub n_runs: usize,
126 pub random_state: Option<u64>,
128 pub include_performance: bool,
130 pub include_memory: bool,
132 pub test_reproducibility: bool,
134 pub datasets: Vec<DatasetConfig>,
136}
137
138#[derive(Debug, Clone)]
140pub struct DatasetConfig {
141 pub name: String,
143 pub data_type: DatasetType,
145 pub size: DatasetSize,
147 pub properties: DatasetProperties,
149}
150
/// Kind of synthetic dataset to generate.
#[derive(Debug, Clone, PartialEq)]
pub enum DatasetType {
    /// Classification with labels drawn uniformly from `0..n_classes`.
    Classification { n_classes: usize },
    /// Regression with a noisy target that is linear in the leading features.
    Regression,
    /// Multiclass classification, generated the same way as `Classification`.
    Multiclass { n_classes: usize },
    /// Binary classification where `floor(n_samples * majority_ratio)` samples get label 0
    /// and the rest get label 1.
    Imbalanced { majority_ratio: f64 },
}
163
/// Shape of a synthetic feature matrix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatasetSize {
    /// Number of rows (samples).
    pub n_samples: usize,
    /// Number of columns (features).
    pub n_features: usize,
}
172
/// Generation settings for a synthetic dataset.
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetProperties {
    /// Half-width of the uniform noise added to the features (classification) or to the
    /// target (regression). Noise is only added when this value is positive.
    pub noise_level: f64,
    /// Intended feature correlation.
    /// NOTE(review): not read by the dataset generators in this file.
    pub correlation: f64,
    /// Fraction of regression targets scaled up into outliers; unused for classification.
    pub outlier_fraction: f64,
    /// RNG seed for data generation; `None` falls back to seed 0.
    pub random_state: Option<u64>,
}
185
186pub struct SklearnBenchmarkFramework {
188 config: BenchmarkConfig,
189}
190
191impl Default for BenchmarkConfig {
192 fn default() -> Self {
193 Self {
194 tolerance: 1e-10,
195 n_runs: 5,
196 random_state: Some(42),
197 include_performance: true,
198 include_memory: false, test_reproducibility: true,
200 datasets: Self::default_datasets(),
201 }
202 }
203}
204
205impl BenchmarkConfig {
206 fn default_datasets() -> Vec<DatasetConfig> {
208 vec![
209 DatasetConfig {
211 name: "small_balanced_classification".to_string(),
212 data_type: DatasetType::Classification { n_classes: 3 },
213 size: DatasetSize {
214 n_samples: 100,
215 n_features: 4,
216 },
217 properties: DatasetProperties {
218 noise_level: 0.1,
219 correlation: 0.0,
220 outlier_fraction: 0.0,
221 random_state: Some(42),
222 },
223 },
224 DatasetConfig {
226 name: "large_classification".to_string(),
227 data_type: DatasetType::Classification { n_classes: 5 },
228 size: DatasetSize {
229 n_samples: 1000,
230 n_features: 20,
231 },
232 properties: DatasetProperties {
233 noise_level: 0.2,
234 correlation: 0.1,
235 outlier_fraction: 0.05,
236 random_state: Some(42),
237 },
238 },
239 DatasetConfig {
241 name: "imbalanced_classification".to_string(),
242 data_type: DatasetType::Imbalanced {
243 majority_ratio: 0.9,
244 },
245 size: DatasetSize {
246 n_samples: 500,
247 n_features: 10,
248 },
249 properties: DatasetProperties {
250 noise_level: 0.1,
251 correlation: 0.0,
252 outlier_fraction: 0.02,
253 random_state: Some(42),
254 },
255 },
256 DatasetConfig {
258 name: "small_regression".to_string(),
259 data_type: DatasetType::Regression,
260 size: DatasetSize {
261 n_samples: 100,
262 n_features: 5,
263 },
264 properties: DatasetProperties {
265 noise_level: 0.1,
266 correlation: 0.0,
267 outlier_fraction: 0.0,
268 random_state: Some(42),
269 },
270 },
271 DatasetConfig {
273 name: "large_regression".to_string(),
274 data_type: DatasetType::Regression,
275 size: DatasetSize {
276 n_samples: 1000,
277 n_features: 15,
278 },
279 properties: DatasetProperties {
280 noise_level: 0.2,
281 correlation: 0.2,
282 outlier_fraction: 0.05,
283 random_state: Some(42),
284 },
285 },
286 ]
287 }
288}
289
290impl SklearnBenchmarkFramework {
291 pub fn new() -> Self {
293 Self {
294 config: BenchmarkConfig::default(),
295 }
296 }
297
298 pub fn with_config(config: BenchmarkConfig) -> Self {
300 Self { config }
301 }
302
303 pub fn benchmark_dummy_classifier(&self) -> Result<Vec<BenchmarkResult>, SklearsError> {
305 let mut results = Vec::new();
306
307 let strategies = vec![
308 ClassifierStrategy::MostFrequent,
309 ClassifierStrategy::Uniform,
310 ClassifierStrategy::Stratified,
311 ClassifierStrategy::Constant,
312 ClassifierStrategy::Prior,
313 ];
314
315 for dataset_config in &self.config.datasets {
316 if let DatasetType::Classification { .. }
317 | DatasetType::Imbalanced { .. }
318 | DatasetType::Multiclass { .. } = dataset_config.data_type
319 {
320 let (X, y) = self.generate_classification_dataset(dataset_config)?;
321
322 for strategy in &strategies {
323 if let Ok(result) =
324 self.benchmark_classifier_strategy(&X, &y, strategy.clone(), dataset_config)
325 {
326 results.push(result);
327 }
328 }
329 }
330 }
331
332 Ok(results)
333 }
334
335 pub fn benchmark_dummy_regressor(&self) -> Result<Vec<BenchmarkResult>, SklearsError> {
337 let mut results = Vec::new();
338
339 let strategies = vec![
340 RegressorStrategy::Mean,
341 RegressorStrategy::Median,
342 RegressorStrategy::Quantile(0.25),
343 RegressorStrategy::Quantile(0.75),
344 RegressorStrategy::Constant(0.0),
345 ];
346
347 for dataset_config in &self.config.datasets {
348 if let DatasetType::Regression = dataset_config.data_type {
349 let (X, y) = self.generate_regression_dataset(dataset_config)?;
350
351 for strategy in &strategies {
352 if let Ok(result) =
353 self.benchmark_regressor_strategy(&X, &y, *strategy, dataset_config)
354 {
355 results.push(result);
356 }
357 }
358 }
359 }
360
361 Ok(results)
362 }
363
364 fn benchmark_classifier_strategy(
366 &self,
367 X: &Array2<f64>,
368 y: &Array1<i32>,
369 strategy: ClassifierStrategy,
370 dataset_config: &DatasetConfig,
371 ) -> Result<BenchmarkResult, SklearsError> {
372 let mut total_fit_time = Duration::new(0, 0);
373 let mut total_predict_time = Duration::new(0, 0);
374 let mut predictions_list = Vec::new();
375
376 for run in 0..self.config.n_runs {
377 let mut classifier = DummyClassifier::new(strategy.clone());
379 if let Some(seed) = self.config.random_state {
380 classifier = classifier.with_random_state(seed + run as u64);
381 }
382
383 let start_fit = Instant::now();
384 let fitted_classifier = classifier.fit(X, y)?;
385 let fit_time = start_fit.elapsed();
386 total_fit_time += fit_time;
387
388 let start_predict = Instant::now();
389 let predictions = fitted_classifier.predict(X)?;
390 let predict_time = start_predict.elapsed();
391 total_predict_time += predict_time;
392
393 predictions_list.push(predictions);
394 }
395
396 let avg_fit_time = total_fit_time / self.config.n_runs as u32;
398 let avg_predict_time = total_predict_time / self.config.n_runs as u32;
399
400 let predictions = &predictions_list[0];
402
403 let accuracy = Self::calculate_accuracy(y, predictions);
405
406 let reference_predictions =
408 self.generate_reference_classifier_predictions(X, y, &strategy)?;
409 let reference_accuracy = Self::calculate_accuracy(y, &reference_predictions);
410
411 let numerical_accuracy =
413 self.calculate_classifier_numerical_accuracy(predictions, &reference_predictions)?;
414
415 let accuracy_comparison = AccuracyComparison {
416 sklears_score: accuracy,
417 reference_score: reference_accuracy,
418 absolute_difference: (accuracy - reference_accuracy).abs(),
419 relative_difference: if reference_accuracy != 0.0 {
420 ((accuracy - reference_accuracy) / reference_accuracy).abs()
421 } else {
422 0.0
423 },
424 within_tolerance: (accuracy - reference_accuracy).abs() <= self.config.tolerance,
425 tolerance_used: self.config.tolerance,
426 };
427
428 let performance_metrics = PerformanceMetrics {
429 fit_time_sklears: avg_fit_time,
430 predict_time_sklears: avg_predict_time,
431 fit_time_reference: Duration::from_millis(1), predict_time_reference: Duration::from_millis(1), speedup_fit: 1.0, speedup_predict: 1.0, memory_usage_sklears: 0, memory_usage_reference: 0, };
438
439 let dataset_info = self.create_classification_dataset_info(dataset_config, X, y);
440
441 Ok(BenchmarkResult {
442 strategy: format!("{:?}", strategy),
443 accuracy_comparison,
444 performance_metrics,
445 numerical_accuracy,
446 dataset_info,
447 })
448 }
449
450 fn benchmark_regressor_strategy(
452 &self,
453 X: &Array2<f64>,
454 y: &Array1<f64>,
455 strategy: RegressorStrategy,
456 dataset_config: &DatasetConfig,
457 ) -> Result<BenchmarkResult, SklearsError> {
458 let mut total_fit_time = Duration::new(0, 0);
459 let mut total_predict_time = Duration::new(0, 0);
460 let mut predictions_list = Vec::new();
461
462 for run in 0..self.config.n_runs {
463 let mut regressor = DummyRegressor::new(strategy);
465 if let Some(seed) = self.config.random_state {
466 regressor = regressor.with_random_state(seed + run as u64);
467 }
468
469 let start_fit = Instant::now();
470 let fitted_regressor = regressor.fit(X, y)?;
471 let fit_time = start_fit.elapsed();
472 total_fit_time += fit_time;
473
474 let start_predict = Instant::now();
475 let predictions = fitted_regressor.predict(X)?;
476 let predict_time = start_predict.elapsed();
477 total_predict_time += predict_time;
478
479 predictions_list.push(predictions);
480 }
481
482 let avg_fit_time = total_fit_time / self.config.n_runs as u32;
484 let avg_predict_time = total_predict_time / self.config.n_runs as u32;
485
486 let predictions = &predictions_list[0];
488
489 let r2_score = Self::calculate_r2_score(y, predictions);
491
492 let reference_predictions =
494 self.generate_reference_regressor_predictions(X, y, &strategy)?;
495 let reference_r2 = Self::calculate_r2_score(y, &reference_predictions);
496
497 let numerical_accuracy =
499 self.calculate_regressor_numerical_accuracy(predictions, &reference_predictions)?;
500
501 let accuracy_comparison = AccuracyComparison {
502 sklears_score: r2_score,
503 reference_score: reference_r2,
504 absolute_difference: (r2_score - reference_r2).abs(),
505 relative_difference: if reference_r2 != 0.0 {
506 ((r2_score - reference_r2) / reference_r2).abs()
507 } else {
508 0.0
509 },
510 within_tolerance: (r2_score - reference_r2).abs() <= self.config.tolerance,
511 tolerance_used: self.config.tolerance,
512 };
513
514 let performance_metrics = PerformanceMetrics {
515 fit_time_sklears: avg_fit_time,
516 predict_time_sklears: avg_predict_time,
517 fit_time_reference: Duration::from_millis(1), predict_time_reference: Duration::from_millis(1), speedup_fit: 1.0, speedup_predict: 1.0, memory_usage_sklears: 0, memory_usage_reference: 0, };
524
525 let dataset_info = self.create_regression_dataset_info(dataset_config, X, y);
526
527 Ok(BenchmarkResult {
528 strategy: format!("{:?}", strategy),
529 accuracy_comparison,
530 performance_metrics,
531 numerical_accuracy,
532 dataset_info,
533 })
534 }
535
536 fn generate_classification_dataset(
538 &self,
539 config: &DatasetConfig,
540 ) -> Result<(Array2<f64>, Array1<i32>), SklearsError> {
541 let mut rng = if let Some(seed) = config.properties.random_state {
542 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
543 } else {
544 scirs2_core::random::rngs::StdRng::seed_from_u64(0)
545 };
546
547 let n_samples = config.size.n_samples;
548 let n_features = config.size.n_features;
549
550 let n_classes = match config.data_type {
551 DatasetType::Classification { n_classes } => n_classes,
552 DatasetType::Multiclass { n_classes } => n_classes,
553 DatasetType::Imbalanced { .. } => 2, _ => {
555 return Err(SklearsError::InvalidParameter {
556 name: "dataset_type".to_string(),
557 reason: "Invalid dataset type for classification".to_string(),
558 })
559 }
560 };
561
562 let mut X = Array2::<f64>::zeros((n_samples, n_features));
564 for i in 0..n_samples {
565 for j in 0..n_features {
566 X[[i, j]] = rng.gen_range(-1.0..1.0);
567 }
568 }
569
570 if config.properties.noise_level > 0.0 {
572 for i in 0..n_samples {
573 for j in 0..n_features {
574 let noise = rng
575 .gen_range(-config.properties.noise_level..config.properties.noise_level);
576 X[[i, j]] += noise;
577 }
578 }
579 }
580
581 let mut y = Array1::<i32>::zeros(n_samples);
583 match config.data_type {
584 DatasetType::Imbalanced { majority_ratio } => {
585 let n_majority = (n_samples as f64 * majority_ratio) as usize;
586 for i in 0..n_samples {
587 y[i] = if i < n_majority { 0 } else { 1 };
588 }
589 for i in 0..n_samples {
591 let j = rng.gen_range(0..n_samples);
592 let temp = y[i];
593 y[i] = y[j];
594 y[j] = temp;
595 }
596 }
597 _ => {
598 for i in 0..n_samples {
599 y[i] = rng.gen_range(0..n_classes as i32);
600 }
601 }
602 }
603
604 Ok((X, y))
605 }
606
607 fn generate_regression_dataset(
609 &self,
610 config: &DatasetConfig,
611 ) -> Result<(Array2<f64>, Array1<f64>), SklearsError> {
612 let mut rng = if let Some(seed) = config.properties.random_state {
613 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
614 } else {
615 scirs2_core::random::rngs::StdRng::seed_from_u64(0)
616 };
617
618 let n_samples = config.size.n_samples;
619 let n_features = config.size.n_features;
620
621 let mut X = Array2::<f64>::zeros((n_samples, n_features));
623 for i in 0..n_samples {
624 for j in 0..n_features {
625 X[[i, j]] = rng.gen_range(-2.0..2.0);
626 }
627 }
628
629 let mut y = Array1::<f64>::zeros(n_samples);
631 for i in 0..n_samples {
632 let mut target = 0.0;
633 for j in 0..n_features.min(3) {
634 target += X[[i, j]] * (j + 1) as f64 * 0.3;
636 }
637
638 if config.properties.noise_level > 0.0 {
640 let noise =
641 rng.gen_range(-config.properties.noise_level..config.properties.noise_level);
642 target += noise;
643 }
644
645 y[i] = target;
646 }
647
648 if config.properties.outlier_fraction > 0.0 {
650 let n_outliers = (n_samples as f64 * config.properties.outlier_fraction) as usize;
651 for _ in 0..n_outliers {
652 let idx = rng.gen_range(0..n_samples);
653 y[idx] *= rng.gen_range(3.0..10.0); }
655 }
656
657 Ok((X, y))
658 }
659
660 fn generate_reference_classifier_predictions(
662 &self,
663 X: &Array2<f64>,
664 y: &Array1<i32>,
665 strategy: &ClassifierStrategy,
666 ) -> Result<Array1<i32>, SklearsError> {
667 let n_samples = X.nrows();
668 let mut predictions = Array1::<i32>::zeros(n_samples);
669
670 match strategy {
671 ClassifierStrategy::MostFrequent => {
672 let mut class_counts = HashMap::new();
674 for &label in y {
675 *class_counts.entry(label).or_insert(0) += 1;
676 }
677 let most_frequent = *class_counts
678 .iter()
679 .max_by_key(|(_, &count)| count)
680 .unwrap()
681 .0;
682 predictions.fill(most_frequent);
683 }
684 ClassifierStrategy::Constant => {
685 predictions.fill(y[0]);
687 }
688 _ => {
689 predictions.fill(y[0]); }
692 }
693
694 Ok(predictions)
695 }
696
697 fn generate_reference_regressor_predictions(
699 &self,
700 X: &Array2<f64>,
701 y: &Array1<f64>,
702 strategy: &RegressorStrategy,
703 ) -> Result<Array1<f64>, SklearsError> {
704 let n_samples = X.nrows();
705 let mut predictions = Array1::<f64>::zeros(n_samples);
706
707 match strategy {
708 RegressorStrategy::Mean => {
709 let mean = y.mean().unwrap_or(0.0);
710 predictions.fill(mean);
711 }
712 RegressorStrategy::Median => {
713 let mut sorted_y = y.to_vec();
714 sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
715 let median = if sorted_y.len() % 2 == 0 {
716 (sorted_y[sorted_y.len() / 2 - 1] + sorted_y[sorted_y.len() / 2]) / 2.0
717 } else {
718 sorted_y[sorted_y.len() / 2]
719 };
720 predictions.fill(median);
721 }
722 RegressorStrategy::Constant(value) => {
723 predictions.fill(*value);
724 }
725 RegressorStrategy::Quantile(q) => {
726 let mut sorted_y = y.to_vec();
727 sorted_y.sort_by(|a, b| a.partial_cmp(b).unwrap());
728 let index = (*q * (sorted_y.len() - 1) as f64) as usize;
729 let quantile = sorted_y[index.min(sorted_y.len() - 1)];
730 predictions.fill(quantile);
731 }
732 _ => {
733 let mean = y.mean().unwrap_or(0.0);
735 predictions.fill(mean);
736 }
737 }
738
739 Ok(predictions)
740 }
741
742 fn calculate_accuracy(y_true: &Array1<i32>, y_pred: &Array1<i32>) -> f64 {
744 let n_samples = y_true.len();
745 if n_samples == 0 {
746 return 0.0;
747 }
748
749 let correct = y_true
750 .iter()
751 .zip(y_pred.iter())
752 .filter(|(&true_val, &pred_val)| true_val == pred_val)
753 .count();
754 correct as f64 / n_samples as f64
755 }
756
757 fn calculate_r2_score(y_true: &Array1<f64>, y_pred: &Array1<f64>) -> f64 {
759 let n_samples = y_true.len();
760 if n_samples == 0 {
761 return 0.0;
762 }
763
764 let y_mean = y_true.mean().unwrap_or(0.0);
765
766 let ss_res: f64 = y_true
767 .iter()
768 .zip(y_pred.iter())
769 .map(|(true_val, pred_val)| (true_val - pred_val).powi(2))
770 .sum();
771
772 let ss_tot: f64 = y_true.iter().map(|val| (val - y_mean).powi(2)).sum();
773
774 if ss_tot == 0.0 {
775 return 0.0;
776 }
777
778 1.0 - (ss_res / ss_tot)
779 }
780
781 fn calculate_classifier_numerical_accuracy(
783 &self,
784 predictions: &Array1<i32>,
785 reference: &Array1<i32>,
786 ) -> Result<NumericalAccuracy, SklearsError> {
787 let n_samples = predictions.len();
788 if n_samples != reference.len() {
789 return Err(SklearsError::InvalidParameter {
790 name: "predictions".to_string(),
791 reason: "Prediction arrays must have same length".to_string(),
792 });
793 }
794
795 let mse = predictions
796 .iter()
797 .zip(reference.iter())
798 .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).powi(2))
799 .sum::<f64>()
800 / n_samples as f64;
801
802 let mae = predictions
803 .iter()
804 .zip(reference.iter())
805 .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).abs())
806 .sum::<f64>()
807 / n_samples as f64;
808
809 let max_error = predictions
810 .iter()
811 .zip(reference.iter())
812 .map(|(pred, ref_val)| (*pred as f64 - *ref_val as f64).abs())
813 .fold(0.0, f64::max);
814
815 let pred_mean = predictions.iter().map(|&x| x as f64).sum::<f64>() / n_samples as f64;
817 let ref_mean = reference.iter().map(|&x| x as f64).sum::<f64>() / n_samples as f64;
818
819 let numerator: f64 = predictions
820 .iter()
821 .zip(reference.iter())
822 .map(|(pred, ref_val)| (*pred as f64 - pred_mean) * (*ref_val as f64 - ref_mean))
823 .sum();
824
825 let pred_var: f64 = predictions
826 .iter()
827 .map(|&x| (x as f64 - pred_mean).powi(2))
828 .sum();
829
830 let ref_var: f64 = reference
831 .iter()
832 .map(|&x| (x as f64 - ref_mean).powi(2))
833 .sum();
834
835 let correlation = if pred_var > 0.0 && ref_var > 0.0 {
836 numerator / (pred_var * ref_var).sqrt()
837 } else {
838 1.0 };
840
841 Ok(NumericalAccuracy {
842 prediction_mse: mse,
843 prediction_mae: mae,
844 max_absolute_error: max_error,
845 correlation,
846 reproducibility_check: true, })
848 }
849
850 fn calculate_regressor_numerical_accuracy(
852 &self,
853 predictions: &Array1<f64>,
854 reference: &Array1<f64>,
855 ) -> Result<NumericalAccuracy, SklearsError> {
856 let n_samples = predictions.len();
857 if n_samples != reference.len() {
858 return Err(SklearsError::InvalidParameter {
859 name: "predictions".to_string(),
860 reason: "Prediction arrays must have same length".to_string(),
861 });
862 }
863
864 let mse = predictions
865 .iter()
866 .zip(reference.iter())
867 .map(|(pred, ref_val)| (pred - ref_val).powi(2))
868 .sum::<f64>()
869 / n_samples as f64;
870
871 let mae = predictions
872 .iter()
873 .zip(reference.iter())
874 .map(|(pred, ref_val)| (pred - ref_val).abs())
875 .sum::<f64>()
876 / n_samples as f64;
877
878 let max_error = predictions
879 .iter()
880 .zip(reference.iter())
881 .map(|(pred, ref_val)| (pred - ref_val).abs())
882 .fold(0.0, f64::max);
883
884 let pred_mean = predictions.mean().unwrap_or(0.0);
886 let ref_mean = reference.mean().unwrap_or(0.0);
887
888 let numerator: f64 = predictions
889 .iter()
890 .zip(reference.iter())
891 .map(|(pred, ref_val)| (pred - pred_mean) * (ref_val - ref_mean))
892 .sum();
893
894 let pred_var: f64 = predictions.iter().map(|x| (x - pred_mean).powi(2)).sum();
895
896 let ref_var: f64 = reference.iter().map(|x| (x - ref_mean).powi(2)).sum();
897
898 let correlation = if pred_var > 0.0 && ref_var > 0.0 {
899 numerator / (pred_var * ref_var).sqrt()
900 } else {
901 1.0 };
903
904 Ok(NumericalAccuracy {
905 prediction_mse: mse,
906 prediction_mae: mae,
907 max_absolute_error: max_error,
908 correlation,
909 reproducibility_check: true, })
911 }
912
913 fn create_classification_dataset_info(
915 &self,
916 config: &DatasetConfig,
917 X: &Array2<f64>,
918 y: &Array1<i32>,
919 ) -> DatasetInfo {
920 let mut class_distribution = HashMap::new();
921 for &label in y {
922 *class_distribution.entry(label).or_insert(0) += 1;
923 }
924
925 let n_classes = class_distribution.len();
926
927 DatasetInfo {
929 name: config.name.clone(),
930 n_samples: X.nrows(),
931 n_features: X.ncols(),
932 n_classes: Some(n_classes),
933 class_distribution: Some(class_distribution),
934 target_statistics: None,
935 }
936 }
937
938 fn create_regression_dataset_info(
940 &self,
941 config: &DatasetConfig,
942 X: &Array2<f64>,
943 y: &Array1<f64>,
944 ) -> DatasetInfo {
945 let mean = y.mean().unwrap_or(0.0);
946 let variance = y.iter().map(|val| (val - mean).powi(2)).sum::<f64>() / y.len() as f64;
947 let std = variance.sqrt();
948 let min = y.iter().fold(f64::INFINITY, |a, &b| a.min(b));
949 let max = y.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
950
951 let skewness = y
953 .iter()
954 .map(|val| ((val - mean) / std).powi(3))
955 .sum::<f64>()
956 / y.len() as f64;
957 let kurtosis = y
958 .iter()
959 .map(|val| ((val - mean) / std).powi(4))
960 .sum::<f64>()
961 / y.len() as f64;
962
963 DatasetInfo {
965 name: config.name.clone(),
966 n_samples: X.nrows(),
967 n_features: X.ncols(),
968 n_classes: None,
969 class_distribution: None,
970 target_statistics: Some(TargetStatistics {
971 mean,
972 std,
973 min,
974 max,
975 skewness,
976 kurtosis,
977 }),
978 }
979 }
980
981 pub fn generate_report(&self, results: &[BenchmarkResult]) -> String {
983 let mut report = String::new();
984
985 report.push_str("# Sklearn Benchmark Report\n\n");
986 report.push_str(&format!("Generated {} results\n\n", results.len()));
987
988 report.push_str("## Summary\n\n");
989
990 let total_within_tolerance = results
991 .iter()
992 .filter(|r| r.accuracy_comparison.within_tolerance)
993 .count();
994 let tolerance_rate = total_within_tolerance as f64 / results.len() as f64 * 100.0;
995
996 report.push_str(&format!(
997 "- **Accuracy within tolerance**: {}/{} ({:.1}%)\n",
998 total_within_tolerance,
999 results.len(),
1000 tolerance_rate
1001 ));
1002
1003 let avg_speedup_fit = results
1004 .iter()
1005 .map(|r| r.performance_metrics.speedup_fit)
1006 .sum::<f64>()
1007 / results.len() as f64;
1008
1009 let avg_speedup_predict = results
1010 .iter()
1011 .map(|r| r.performance_metrics.speedup_predict)
1012 .sum::<f64>()
1013 / results.len() as f64;
1014
1015 report.push_str(&format!(
1016 "- **Average fit speedup**: {:.2}x\n",
1017 avg_speedup_fit
1018 ));
1019 report.push_str(&format!(
1020 "- **Average predict speedup**: {:.2}x\n",
1021 avg_speedup_predict
1022 ));
1023
1024 report.push_str("\n## Detailed Results\n\n");
1025
1026 for result in results {
1027 report.push_str(&format!(
1028 "### {} on {}\n\n",
1029 result.strategy, result.dataset_info.name
1030 ));
1031
1032 report.push_str("**Accuracy Comparison:**\n");
1033 report.push_str(&format!(
1034 "- Sklears score: {:.6}\n",
1035 result.accuracy_comparison.sklears_score
1036 ));
1037 report.push_str(&format!(
1038 "- Reference score: {:.6}\n",
1039 result.accuracy_comparison.reference_score
1040 ));
1041 report.push_str(&format!(
1042 "- Absolute difference: {:.6}\n",
1043 result.accuracy_comparison.absolute_difference
1044 ));
1045 report.push_str(&format!(
1046 "- Within tolerance: {}\n",
1047 result.accuracy_comparison.within_tolerance
1048 ));
1049
1050 report.push_str("\n**Performance Metrics:**\n");
1051 report.push_str(&format!(
1052 "- Fit time: {:?}\n",
1053 result.performance_metrics.fit_time_sklears
1054 ));
1055 report.push_str(&format!(
1056 "- Predict time: {:?}\n",
1057 result.performance_metrics.predict_time_sklears
1058 ));
1059
1060 report.push_str("\n**Numerical Accuracy:**\n");
1061 report.push_str(&format!(
1062 "- MSE: {:.6}\n",
1063 result.numerical_accuracy.prediction_mse
1064 ));
1065 report.push_str(&format!(
1066 "- MAE: {:.6}\n",
1067 result.numerical_accuracy.prediction_mae
1068 ));
1069 report.push_str(&format!(
1070 "- Correlation: {:.6}\n",
1071 result.numerical_accuracy.correlation
1072 ));
1073
1074 report.push_str("\n---\n\n");
1075 }
1076
1077 report
1078 }
1079}
1080
1081impl Default for SklearnBenchmarkFramework {
1082 fn default() -> Self {
1083 Self::new()
1084 }
1085}
1086
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Seeded dataset description: 100 samples, 4 features, 3 classes.
    fn three_class_dataset() -> DatasetConfig {
        DatasetConfig {
            name: "test".to_string(),
            data_type: DatasetType::Classification { n_classes: 3 },
            size: DatasetSize {
                n_samples: 100,
                n_features: 4,
            },
            properties: DatasetProperties {
                noise_level: 0.1,
                correlation: 0.0,
                outlier_fraction: 0.0,
                random_state: Some(42),
            },
        }
    }

    #[test]
    fn test_benchmark_framework_creation() {
        let BenchmarkConfig {
            tolerance, n_runs, ..
        } = SklearnBenchmarkFramework::new().config;
        assert_eq!(tolerance, 1e-10);
        assert_eq!(n_runs, 5);
    }

    #[test]
    fn test_synthetic_dataset_generation() {
        let (X, y) = SklearnBenchmarkFramework::new()
            .generate_classification_dataset(&three_class_dataset())
            .unwrap();

        assert_eq!(X.dim(), (100, 4));
        assert_eq!(y.len(), 100);
        // Every label must be a valid class index for a 3-class problem.
        assert!(y.iter().all(|&label| (0..3).contains(&label)));
    }

    #[test]
    fn test_accuracy_calculation() {
        let truth = Array1::from(vec![0, 1, 2, 1, 0]);
        let guess = Array1::from(vec![0, 1, 1, 1, 0]);

        // Four of the five labels agree.
        let accuracy = SklearnBenchmarkFramework::calculate_accuracy(&truth, &guess);
        assert!((accuracy - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_r2_score_calculation() {
        let observed = Array1::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let predicted = Array1::from(vec![1.1, 1.9, 3.1, 3.9, 5.1]);

        // Small, alternating errors should keep R² close to 1.
        let r2 = SklearnBenchmarkFramework::calculate_r2_score(&observed, &predicted);
        assert!(r2 > 0.9);
    }

    #[test]
    fn test_benchmark_classifier() {
        let results = SklearnBenchmarkFramework::new()
            .benchmark_dummy_classifier()
            .unwrap();
        assert!(!results.is_empty());

        for result in &results {
            assert!(!result.strategy.is_empty());
            let score = result.accuracy_comparison.sklears_score;
            assert!((0.0..=1.0).contains(&score));
        }
    }

    #[test]
    fn test_benchmark_regressor() {
        let results = SklearnBenchmarkFramework::new()
            .benchmark_dummy_regressor()
            .unwrap();
        assert!(!results.is_empty());

        for result in &results {
            assert!(!result.strategy.is_empty());
            assert!(result.accuracy_comparison.sklears_score.is_finite());
        }
    }

    #[test]
    fn test_report_generation() {
        let framework = SklearnBenchmarkFramework::new();
        let report = framework.generate_report(&framework.benchmark_dummy_classifier().unwrap());

        for heading in ["Sklearn Benchmark Report", "Summary", "Detailed Results"] {
            assert!(report.contains(heading));
        }
    }
}
1188}