//! Benchmarking suite for explanation methods.
//!
//! Provides configuration, timing/memory/quality measurement, optional
//! reference comparisons, and report generation for interpretability
//! benchmarks.

use crate::SklResult;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use sklears_core::{error::SklearsError, types::Float};
use std::collections::HashMap;
use std::time::{Duration, Instant};

#[cfg(feature = "serde")]
use chrono::{DateTime, Utc};

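/// Configuration for a benchmarking run.
///
/// A minimal sketch of overriding the defaults (imports elided); it assumes
/// only the fields defined below:
///
/// ```ignore
/// let config = BenchmarkConfig {
///     warmup_iterations: 5,
///     benchmark_iterations: 50,
///     memory_profiling: false,
///     significance_level: 0.01,
///     categories: vec![BenchmarkCategory::FeatureImportance],
///     reference_implementation: Some("sklearn".to_string()),
/// };
/// ```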
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BenchmarkConfig {
    /// Number of warmup iterations run before timing starts.
    pub warmup_iterations: usize,
    /// Number of measured iterations per method.
    pub benchmark_iterations: usize,
    /// Whether to collect memory statistics alongside timings.
    pub memory_profiling: bool,
    /// Significance level used when comparing against a reference.
    pub significance_level: Float,
    /// Benchmark categories to run.
    pub categories: Vec<BenchmarkCategory>,
    /// Optional name of a registered reference implementation to compare against.
    pub reference_implementation: Option<String>,
}

impl Default for BenchmarkConfig {
    fn default() -> Self {
        Self {
            warmup_iterations: 10,
            benchmark_iterations: 100,
            memory_profiling: true,
            significance_level: 0.05,
            categories: vec![
                BenchmarkCategory::FeatureImportance,
                BenchmarkCategory::LocalExplanations,
                BenchmarkCategory::GlobalExplanations,
                BenchmarkCategory::Visualization,
            ],
            reference_implementation: None,
        }
    }
}

/// Categories of benchmarks supported by the suite.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BenchmarkCategory {
    /// Feature importance methods (e.g. permutation importance, SHAP).
    FeatureImportance,
    /// Local, per-instance explanation methods.
    LocalExplanations,
    /// Global, model-level explanation methods.
    GlobalExplanations,
    /// Visualization generation.
    Visualization,
    /// Comparison across different models.
    ModelComparison,
    /// Uncertainty quantification for explanations.
    UncertaintyQuantification,
    /// Run every category.
    All,
}

/// Result of benchmarking a single method under one test configuration.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BenchmarkResult {
    /// Name of the benchmarked method.
    pub method_name: String,
    /// Category the method belongs to.
    pub category: BenchmarkCategory,
    /// Timing statistics over all measured iterations.
    pub timing_stats: TimingStatistics,
    /// Memory statistics, if memory profiling was enabled.
    pub memory_stats: Option<MemoryStatistics>,
    /// Explanation quality metrics.
    pub quality_metrics: QualityMetrics,
    /// Comparison against a reference implementation, if one was configured.
    pub reference_comparison: Option<ReferenceComparison>,
    /// Configuration of the test that produced this result.
    pub test_config: TestConfiguration,
    /// When the benchmark was run.
    #[cfg(feature = "serde")]
    pub timestamp: DateTime<Utc>,
}

/// Timing statistics collected over the measured iterations.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TimingStatistics {
    /// Mean execution time.
    pub mean_time: Duration,
    /// Standard deviation of execution times.
    pub std_dev: Duration,
    /// Median execution time.
    pub median_time: Duration,
    /// Fastest observed execution time.
    pub min_time: Duration,
    /// Slowest observed execution time.
    pub max_time: Duration,
    /// 95th percentile execution time.
    pub percentile_95: Duration,
    /// Throughput in operations per second (1 / mean time).
    pub throughput: Float,
}

/// Memory usage statistics collected during benchmarking.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct MemoryStatistics {
    /// Peak memory usage observed, in bytes.
    pub peak_memory: usize,
    /// Average memory usage, in bytes.
    pub avg_memory: usize,
    /// Number of recorded allocations.
    pub allocations: usize,
    /// Number of recorded deallocations.
    pub deallocations: usize,
    /// Heuristic memory-efficiency score in [0, 1].
    pub efficiency_score: Float,
}

/// Quality metrics describing how good an explanation method is.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct QualityMetrics {
    /// How faithfully the explanation reflects the underlying model.
    pub fidelity: Float,
    /// Robustness of explanations to small input perturbations.
    pub stability: Float,
    /// Agreement of explanations across repeated runs.
    pub consistency: Float,
    /// Coverage of the relevant model behaviour.
    pub completeness: Float,
    /// How easy the explanation is for humans to interpret.
    pub interpretability: Float,
    /// Aggregate quality score.
    pub overall_score: Float,
}

/// Comparison of a benchmark result against a reference implementation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ReferenceComparison {
    /// Speed ratio (reference mean time / current mean time); above 1.0 means faster.
    pub speed_improvement: Float,
    /// Memory ratio (reference peak / current peak); above 1.0 means less memory.
    pub memory_improvement: Float,
    /// Difference in overall quality score (current minus reference).
    pub quality_difference: Float,
    /// Whether the difference is statistically significant at the configured level.
    pub is_significant: bool,
    /// p-value of the significance test.
    pub p_value: Float,
}

/// Configuration of the synthetic test used to benchmark a method.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TestConfiguration {
    /// Number of samples in the dataset.
    pub dataset_size: usize,
    /// Number of features per sample.
    pub num_features: usize,
    /// Name of the model being explained.
    pub model_type: String,
    /// Type of prediction problem.
    pub problem_type: ProblemType,
    /// Additional method-specific parameters.
    pub parameters: HashMap<String, String>,
}

/// Type of prediction problem in a test configuration.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ProblemType {
    /// Two-class classification.
    BinaryClassification,
    /// Classification with more than two classes.
    MultiClassification,
    /// Continuous-target regression.
    Regression,
}
pub struct BenchmarkingSuite {
    config: BenchmarkConfig,
    results: Vec<BenchmarkResult>,
    reference_results: HashMap<String, BenchmarkResult>,
}

impl BenchmarkingSuite {
    /// Creates a new suite with the given configuration.
    pub fn new(config: BenchmarkConfig) -> Self {
        Self {
            config,
            results: Vec::new(),
            reference_results: HashMap::new(),
        }
    }

    /// Runs all configured benchmark categories and generates a report.
    pub fn run_benchmarks(&mut self) -> SklResult<BenchmarkReport> {
        for category in self.config.categories.clone() {
            self.benchmark_category(category)?;
        }

        self.generate_report()
    }

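    /// Benchmarks a single method closure under the given test configuration.
    ///
    /// A minimal usage sketch (imports elided); the closure and configuration
    /// values are purely illustrative:
    ///
    /// ```ignore
    /// let mut suite = BenchmarkingSuite::new(BenchmarkConfig::default());
    /// let result = suite.benchmark_method(
    ///     "MyMethod".to_string(),
    ///     BenchmarkCategory::FeatureImportance,
    ///     || Ok(42_i32),
    ///     TestConfiguration {
    ///         dataset_size: 100,
    ///         num_features: 5,
    ///         model_type: "LinearModel".to_string(),
    ///         problem_type: ProblemType::Regression,
    ///         parameters: HashMap::new(),
    ///     },
    /// )?;
    /// assert!(result.timing_stats.throughput >= 0.0);
    /// ```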
    pub fn benchmark_method<F, T>(
        &mut self,
        method_name: String,
        category: BenchmarkCategory,
        method_fn: F,
        test_config: TestConfiguration,
    ) -> SklResult<BenchmarkResult>
    where
        F: Fn() -> SklResult<T>,
        T: std::fmt::Debug,
    {
        // Warmup iterations are executed but not measured.
        for _ in 0..self.config.warmup_iterations {
            let _ = method_fn()?;
        }

        let mut times = Vec::new();
        let mut memory_snapshots = Vec::new();

        for _ in 0..self.config.benchmark_iterations {
            let memory_before = if self.config.memory_profiling {
                Some(self.get_memory_usage())
            } else {
                None
            };

            let start_time = Instant::now();
            let _result = method_fn()?;
            let elapsed = start_time.elapsed();

            times.push(elapsed);

            if let Some(mem_before) = memory_before {
                let memory_after = self.get_memory_usage();
                memory_snapshots.push((mem_before, memory_after));
            }
        }

        let timing_stats = self.calculate_timing_statistics(&times);

        let memory_stats = if self.config.memory_profiling {
            Some(self.calculate_memory_statistics(&memory_snapshots))
        } else {
            None
        };

        let quality_metrics = self.calculate_quality_metrics(&method_name, &test_config)?;

        let reference_comparison = if let Some(ref_name) = &self.config.reference_implementation {
            self.reference_results.get(ref_name).map(|ref_result| {
                self.compare_with_reference(
                    &timing_stats,
                    &memory_stats,
                    &quality_metrics,
                    ref_result,
                )
            })
        } else {
            None
        };

        let result = BenchmarkResult {
            method_name,
            category,
            timing_stats,
            memory_stats,
            quality_metrics,
            reference_comparison,
            test_config,
            #[cfg(feature = "serde")]
            timestamp: Utc::now(),
        };

        self.results.push(result.clone());
        Ok(result)
    }

    /// Benchmarks the built-in feature-importance methods on a set of
    /// synthetic test configurations.
    pub fn benchmark_feature_importance(&mut self) -> SklResult<Vec<BenchmarkResult>> {
        let mut results = Vec::new();

        let test_configs = vec![
            TestConfiguration {
                dataset_size: 1000,
                num_features: 10,
                model_type: "RandomForest".to_string(),
                problem_type: ProblemType::BinaryClassification,
                parameters: HashMap::new(),
            },
            TestConfiguration {
                dataset_size: 5000,
                num_features: 50,
                model_type: "RandomForest".to_string(),
                problem_type: ProblemType::MultiClassification,
                parameters: HashMap::new(),
            },
        ];

        for config in test_configs {
            let config_clone = config.clone();
            let perm_result = self.benchmark_method(
                "PermutationImportance".to_string(),
                BenchmarkCategory::FeatureImportance,
                move || Self::simulate_permutation_importance_static(&config_clone),
                config.clone(),
            )?;
            results.push(perm_result);

            let config_clone = config.clone();
            let shap_result = self.benchmark_method(
                "SHAP".to_string(),
                BenchmarkCategory::FeatureImportance,
                move || Self::simulate_shap_computation_static(&config_clone),
                config.clone(),
            )?;
            results.push(shap_result);
        }

        Ok(results)
    }

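    /// Registers a previously collected result as a named reference
    /// implementation.
    ///
    /// A minimal sketch of the reference-comparison workflow (imports and the
    /// construction of `reference_result` elided; the name "sklearn" is
    /// illustrative only):
    ///
    /// ```ignore
    /// let mut config = BenchmarkConfig::default();
    /// config.reference_implementation = Some("sklearn".to_string());
    /// let mut suite = BenchmarkingSuite::new(config);
    /// suite.add_reference_result("sklearn".to_string(), reference_result);
    /// // Subsequent `benchmark_method` calls now populate `reference_comparison`.
    /// ```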
    pub fn add_reference_result(&mut self, name: String, result: BenchmarkResult) {
        self.reference_results.insert(name, result);
    }

    /// Generates a report from all results collected so far.
    pub fn generate_report(&self) -> SklResult<BenchmarkReport> {
        // Group results by category.
        let mut category_summaries = HashMap::new();

        for result in &self.results {
            let category_results = category_summaries
                .entry(result.category)
                .or_insert_with(Vec::new);
            category_results.push(result.clone());
        }

        let mut summaries = HashMap::new();
        for (category, results) in category_summaries {
            summaries.insert(category, self.generate_category_summary(&results));
        }

        let insights = self.generate_performance_insights();

        let recommendations = self.generate_recommendations();

        Ok(BenchmarkReport {
            config: self.config.clone(),
            results: self.results.clone(),
            category_summaries: summaries,
            performance_insights: insights,
            recommendations,
            #[cfg(feature = "serde")]
            generated_at: Utc::now(),
        })
    }

    fn benchmark_category(&mut self, category: BenchmarkCategory) -> SklResult<()> {
        match category {
            BenchmarkCategory::FeatureImportance => {
                self.benchmark_feature_importance()?;
            }
            BenchmarkCategory::LocalExplanations => {
                self.benchmark_local_explanations()?;
            }
            BenchmarkCategory::GlobalExplanations => {
                self.benchmark_global_explanations()?;
            }
            BenchmarkCategory::Visualization => {
                self.benchmark_visualization()?;
            }
            BenchmarkCategory::All => {
                self.benchmark_feature_importance()?;
                self.benchmark_local_explanations()?;
                self.benchmark_global_explanations()?;
                self.benchmark_visualization()?;
            }
            // ModelComparison and UncertaintyQuantification are not yet implemented.
            _ => {}
        }
        Ok(())
    }

    fn benchmark_local_explanations(&mut self) -> SklResult<Vec<BenchmarkResult>> {
        // Placeholder: local-explanation benchmarks are not yet implemented.
        Ok(Vec::new())
    }

    fn benchmark_global_explanations(&mut self) -> SklResult<Vec<BenchmarkResult>> {
        // Placeholder: global-explanation benchmarks are not yet implemented.
        Ok(Vec::new())
    }

    fn benchmark_visualization(&mut self) -> SklResult<Vec<BenchmarkResult>> {
        // Placeholder: visualization benchmarks are not yet implemented.
        Ok(Vec::new())
    }

    fn calculate_timing_statistics(&self, times: &[Duration]) -> TimingStatistics {
        if times.is_empty() {
            return TimingStatistics {
                mean_time: Duration::from_secs(0),
                std_dev: Duration::from_secs(0),
                median_time: Duration::from_secs(0),
                min_time: Duration::from_secs(0),
                max_time: Duration::from_secs(0),
                percentile_95: Duration::from_secs(0),
                throughput: 0.0,
            };
        }

        let mut sorted_times = times.to_vec();
        sorted_times.sort();

        let mean_nanos: u128 =
            times.iter().map(|d| d.as_nanos()).sum::<u128>() / times.len() as u128;
        let mean_time = Duration::from_nanos(mean_nanos as u64);

        let variance: f64 = times
            .iter()
            .map(|d| {
                let diff = d.as_nanos() as f64 - mean_nanos as f64;
                diff * diff
            })
            .sum::<f64>()
            / times.len() as f64;
        let std_dev = Duration::from_nanos(variance.sqrt() as u64);

        let median_time = sorted_times[times.len() / 2];
        let min_time = *sorted_times.first().unwrap();
        let max_time = *sorted_times.last().unwrap();
        let percentile_95 = sorted_times[(times.len() as f64 * 0.95) as usize];

        let throughput = if mean_time.as_secs_f64() > 0.0 {
            1.0 / mean_time.as_secs_f64()
        } else {
            0.0
        };

        TimingStatistics {
            mean_time,
            std_dev,
            median_time,
            min_time,
            max_time,
            percentile_95,
            throughput,
        }
    }

    fn calculate_memory_statistics(&self, snapshots: &[(usize, usize)]) -> MemoryStatistics {
        if snapshots.is_empty() {
            return MemoryStatistics {
                peak_memory: 0,
                avg_memory: 0,
                allocations: 0,
                deallocations: 0,
                efficiency_score: 0.0,
            };
        }

        let peak_memory = snapshots.iter().map(|(_, after)| *after).max().unwrap_or(0);

        let avg_memory = snapshots.iter().map(|(_, after)| *after).sum::<usize>() / snapshots.len();

        MemoryStatistics {
            peak_memory,
            avg_memory,
            allocations: snapshots.len(),
            deallocations: snapshots.len(),
            // Placeholder score until real allocation tracking is available.
            efficiency_score: 0.8,
        }
    }

    fn calculate_quality_metrics(
        &self,
        method_name: &str,
        _config: &TestConfiguration,
    ) -> SklResult<QualityMetrics> {
        // Heuristic per-method base scores; these stand in for evaluation
        // against ground-truth explanations.
        let base_score = match method_name {
            "PermutationImportance" => 0.85,
            "SHAP" => 0.90,
            "LIME" => 0.80,
            _ => 0.75,
        };

        Ok(QualityMetrics {
            fidelity: base_score + 0.05,
            stability: base_score - 0.02,
            consistency: base_score + 0.03,
            completeness: base_score - 0.05,
            interpretability: base_score + 0.02,
            overall_score: base_score,
        })
    }

    fn compare_with_reference(
        &self,
        timing: &TimingStatistics,
        memory: &Option<MemoryStatistics>,
        quality: &QualityMetrics,
        reference: &BenchmarkResult,
    ) -> ReferenceComparison {
        // Ratios above 1.0 mean the current method is faster / uses less
        // memory than the reference.
        let speed_improvement =
            reference.timing_stats.mean_time.as_secs_f64() / timing.mean_time.as_secs_f64();

        let memory_improvement =
            if let (Some(current_mem), Some(ref_mem)) = (memory, &reference.memory_stats) {
                ref_mem.peak_memory as f64 / current_mem.peak_memory as f64
            } else {
                1.0
            };

        let quality_difference = quality.overall_score - reference.quality_metrics.overall_score;

        // Placeholder p-value until a proper significance test is implemented.
        let p_value = 0.01;
        let is_significant = p_value < self.config.significance_level;

        ReferenceComparison {
            speed_improvement,
            memory_improvement,
            quality_difference,
            is_significant,
            p_value,
        }
    }

    fn generate_category_summary(&self, results: &[BenchmarkResult]) -> CategorySummary {
        if results.is_empty() {
            return CategorySummary {
                best_method: "None".to_string(),
                worst_method: "None".to_string(),
                avg_throughput: 0.0,
                avg_quality: 0.0,
                performance_ranking: Vec::new(),
            };
        }

        let best_method = results
            .iter()
            .max_by(|a, b| {
                a.timing_stats
                    .throughput
                    .partial_cmp(&b.timing_stats.throughput)
                    .unwrap()
            })
            .unwrap()
            .method_name
            .clone();

        let worst_method = results
            .iter()
            .min_by(|a, b| {
                a.timing_stats
                    .throughput
                    .partial_cmp(&b.timing_stats.throughput)
                    .unwrap()
            })
            .unwrap()
            .method_name
            .clone();

        let avg_throughput = results
            .iter()
            .map(|r| r.timing_stats.throughput)
            .sum::<Float>()
            / results.len() as Float;

        let avg_quality = results
            .iter()
            .map(|r| r.quality_metrics.overall_score)
            .sum::<Float>()
            / results.len() as Float;

        let mut performance_ranking: Vec<(String, Float)> = results
            .iter()
            .map(|r| (r.method_name.clone(), r.timing_stats.throughput))
            .collect();
        performance_ranking.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        CategorySummary {
            best_method,
            worst_method,
            avg_throughput,
            avg_quality,
            performance_ranking,
        }
    }

    fn generate_performance_insights(&self) -> Vec<PerformanceInsight> {
        let mut insights = Vec::new();

        let fastest_method = self.results.iter().max_by(|a, b| {
            a.timing_stats
                .throughput
                .partial_cmp(&b.timing_stats.throughput)
                .unwrap()
        });

        if let Some(method) = fastest_method {
            insights.push(PerformanceInsight {
                insight_type: InsightType::Speed,
                message: format!(
                    "{} is the fastest method with {:.2} ops/sec",
                    method.method_name, method.timing_stats.throughput
                ),
                severity: InsightSeverity::Info,
            });
        }

        let highest_quality = self.results.iter().max_by(|a, b| {
            a.quality_metrics
                .overall_score
                .partial_cmp(&b.quality_metrics.overall_score)
                .unwrap()
        });

        if let Some(method) = highest_quality {
            insights.push(PerformanceInsight {
                insight_type: InsightType::Quality,
                message: format!(
                    "{} has the highest quality score: {:.3}",
                    method.method_name, method.quality_metrics.overall_score
                ),
                severity: InsightSeverity::Info,
            });
        }

        insights
    }

    fn generate_recommendations(&self) -> Vec<String> {
        let mut recommendations = Vec::new();

        recommendations.push("Consider using parallel processing for large datasets".to_string());
        recommendations.push("Enable caching for repeated computations".to_string());
        recommendations.push("Profile memory usage for memory-intensive operations".to_string());

        recommendations
    }

    fn get_memory_usage(&self) -> usize {
        // Placeholder: reports a fixed 1 MiB until real memory profiling is wired in.
        1024 * 1024
    }

    fn simulate_permutation_importance_static(config: &TestConfiguration) -> SklResult<String> {
        // Simulated workload whose duration scales with dataset size and feature count.
        let computation_time =
            Duration::from_millis((config.dataset_size * config.num_features / 100) as u64);
        std::thread::sleep(computation_time);
        Ok("Permutation importance computed".to_string())
    }

    fn simulate_shap_computation_static(config: &TestConfiguration) -> SklResult<String> {
        // Simulated workload; SHAP is modelled as roughly twice as expensive.
        let computation_time =
            Duration::from_millis((config.dataset_size * config.num_features / 50) as u64);
        std::thread::sleep(computation_time);
        Ok("SHAP values computed".to_string())
    }
}

/// Summary statistics for all methods benchmarked within one category.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CategorySummary {
    /// Method with the highest throughput.
    pub best_method: String,
    /// Method with the lowest throughput.
    pub worst_method: String,
    /// Average throughput across methods (ops/sec).
    pub avg_throughput: Float,
    /// Average overall quality score across methods.
    pub avg_quality: Float,
    /// Methods ranked by throughput, highest first.
    pub performance_ranking: Vec<(String, Float)>,
}

/// An automatically generated observation about benchmark results.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct PerformanceInsight {
    /// Aspect of performance the insight concerns.
    pub insight_type: InsightType,
    /// Human-readable description of the insight.
    pub message: String,
    /// How important the insight is.
    pub severity: InsightSeverity,
}

/// Aspect of performance an insight refers to.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum InsightType {
    Speed,
    Memory,
    Quality,
    Scalability,
}

/// Severity level of a performance insight.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum InsightSeverity {
    Info,
    Warning,
    Critical,
}

/// Full benchmark report combining raw results, summaries, insights, and
/// recommendations.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct BenchmarkReport {
    /// Configuration the benchmarks were run with.
    pub config: BenchmarkConfig,
    /// All individual benchmark results.
    pub results: Vec<BenchmarkResult>,
    /// Per-category summaries.
    pub category_summaries: HashMap<BenchmarkCategory, CategorySummary>,
    /// Automatically generated insights.
    pub performance_insights: Vec<PerformanceInsight>,
    /// Textual recommendations.
    pub recommendations: Vec<String>,
    /// When the report was generated.
    #[cfg(feature = "serde")]
    pub generated_at: DateTime<Utc>,
}

impl BenchmarkReport {
    /// Serializes the report to pretty-printed JSON (requires the `serde` feature).
    #[cfg(feature = "serde")]
    pub fn to_json(&self) -> SklResult<String> {
        serde_json::to_string_pretty(self)
            .map_err(|e| SklearsError::InvalidInput(format!("Failed to serialize report: {}", e)))
    }

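    /// Renders the report as a standalone HTML page.
    ///
    /// A minimal sketch of writing the rendered report to disk (error handling
    /// elided; the file name is illustrative):
    ///
    /// ```ignore
    /// let html = report.to_html();
    /// std::fs::write("benchmark_report.html", html)?;
    /// ```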
    pub fn to_html(&self) -> String {
        let generated_time = {
            #[cfg(feature = "serde")]
            {
                self.generated_at
                    .format("%Y-%m-%d %H:%M:%S UTC")
                    .to_string()
            }
            #[cfg(not(feature = "serde"))]
            {
                "N/A".to_string()
            }
        };

        format!(
            r#"<!DOCTYPE html>
<html>
<head>
    <title>Benchmark Report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        .header {{ background-color: #f0f0f0; padding: 20px; border-radius: 5px; }}
        .section {{ margin: 20px 0; }}
        .result {{ background-color: #f9f9f9; padding: 15px; margin: 10px 0; border-radius: 5px; }}
        .performance-table {{ width: 100%; border-collapse: collapse; }}
        .performance-table th, .performance-table td {{
            border: 1px solid #ddd; padding: 8px; text-align: left;
        }}
        .performance-table th {{ background-color: #f2f2f2; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>Benchmark Report</h1>
        <p>Generated: {}</p>
        <p>Total Methods Tested: {}</p>
    </div>

    <div class="section">
        <h2>Performance Summary</h2>
        <table class="performance-table">
            <tr>
                <th>Method</th>
                <th>Category</th>
                <th>Mean Time (ms)</th>
                <th>Throughput (ops/sec)</th>
                <th>Quality Score</th>
            </tr>
            {}
        </table>
    </div>

    <div class="section">
        <h2>Insights</h2>
        {}
    </div>

    <div class="section">
        <h2>Recommendations</h2>
        <ul>
            {}
        </ul>
    </div>
</body>
</html>"#,
            generated_time,
            self.results.len(),
            self.results
                .iter()
                .map(|r| format!(
                    "<tr><td>{}</td><td>{:?}</td><td>{:.2}</td><td>{:.2}</td><td>{:.3}</td></tr>",
                    r.method_name,
                    r.category,
                    r.timing_stats.mean_time.as_millis(),
                    r.timing_stats.throughput,
                    r.quality_metrics.overall_score
                ))
                .collect::<Vec<_>>()
                .join("\n"),
            self.performance_insights
                .iter()
                .map(|insight| format!(
                    "<p><strong>{:?}:</strong> {}</p>",
                    insight.insight_type, insight.message
                ))
                .collect::<Vec<_>>()
                .join("\n"),
            self.recommendations
                .iter()
                .map(|rec| format!("<li>{}</li>", rec))
                .collect::<Vec<_>>()
                .join("\n")
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_benchmark_config_creation() {
        let config = BenchmarkConfig::default();
        assert_eq!(config.warmup_iterations, 10);
        assert_eq!(config.benchmark_iterations, 100);
        assert_eq!(config.significance_level, 0.05);
    }

    #[test]
    fn test_benchmark_suite_creation() {
        let config = BenchmarkConfig::default();
        let suite = BenchmarkingSuite::new(config);
        assert_eq!(suite.results.len(), 0);
        assert_eq!(suite.reference_results.len(), 0);
    }

    #[test]
    fn test_timing_statistics_calculation() {
        let config = BenchmarkConfig::default();
        let suite = BenchmarkingSuite::new(config);

        let times = vec![
            Duration::from_millis(100),
            Duration::from_millis(200),
            Duration::from_millis(150),
        ];

        let stats = suite.calculate_timing_statistics(&times);
        assert!(stats.mean_time.as_millis() > 0);
        assert!(stats.throughput > 0.0);
    }

    #[test]
    fn test_quality_metrics_calculation() {
        let config = BenchmarkConfig::default();
        let suite = BenchmarkingSuite::new(config);

        let test_config = TestConfiguration {
            dataset_size: 1000,
            num_features: 10,
            model_type: "Test".to_string(),
            problem_type: ProblemType::BinaryClassification,
            parameters: HashMap::new(),
        };

        let metrics = suite
            .calculate_quality_metrics("SHAP", &test_config)
            .unwrap();
        assert!(metrics.overall_score > 0.0);
        assert!(metrics.fidelity > 0.0);
    }

    #[test]
    fn test_memory_statistics_calculation() {
        let config = BenchmarkConfig::default();
        let suite = BenchmarkingSuite::new(config);

        let snapshots = vec![(1000, 1500), (1200, 1800), (1100, 1600)];
        let stats = suite.calculate_memory_statistics(&snapshots);

        assert_eq!(stats.peak_memory, 1800);
        assert!(stats.avg_memory > 0);
    }

    #[test]
    fn test_test_configuration() {
        let config = TestConfiguration {
            dataset_size: 5000,
            num_features: 20,
            model_type: "RandomForest".to_string(),
            problem_type: ProblemType::MultiClassification,
            parameters: HashMap::new(),
        };

        assert_eq!(config.dataset_size, 5000);
        assert_eq!(config.num_features, 20);
        assert!(matches!(
            config.problem_type,
            ProblemType::MultiClassification
        ));
    }

    #[test]
    fn test_reference_comparison() {
        let comparison = ReferenceComparison {
            speed_improvement: 2.5,
            memory_improvement: 1.8,
            quality_difference: 0.05,
            is_significant: true,
            p_value: 0.01,
        };

        assert!(comparison.speed_improvement > 2.0);
        assert!(comparison.is_significant);
        assert!(comparison.p_value < 0.05);
    }

    #[test]
    fn test_performance_insight() {
        let insight = PerformanceInsight {
            insight_type: InsightType::Speed,
            message: "Method A is 3x faster than Method B".to_string(),
            severity: InsightSeverity::Info,
        };

        assert!(matches!(insight.insight_type, InsightType::Speed));
        assert!(matches!(insight.severity, InsightSeverity::Info));
        assert!(insight.message.contains("3x faster"));
    }

    #[test]
    fn test_html_report_generation() {
        let config = BenchmarkConfig::default();
        let report = BenchmarkReport {
            config,
            results: Vec::new(),
            category_summaries: HashMap::new(),
            performance_insights: Vec::new(),
            recommendations: vec!["Test recommendation".to_string()],
            #[cfg(feature = "serde")]
            generated_at: Utc::now(),
        };

        let html = report.to_html();
        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("Benchmark Report"));
        assert!(html.contains("Test recommendation"));
    }
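
    // Illustrative end-to-end sketch of `benchmark_method`: it assumes nothing
    // beyond the public API above and uses a tiny configuration with a trivial
    // closure so the test stays fast.
    #[test]
    fn test_benchmark_method_end_to_end() {
        let config = BenchmarkConfig {
            warmup_iterations: 1,
            benchmark_iterations: 3,
            memory_profiling: false,
            significance_level: 0.05,
            categories: vec![BenchmarkCategory::FeatureImportance],
            reference_implementation: None,
        };
        let mut suite = BenchmarkingSuite::new(config);

        let test_config = TestConfiguration {
            dataset_size: 10,
            num_features: 2,
            model_type: "Dummy".to_string(),
            problem_type: ProblemType::Regression,
            parameters: HashMap::new(),
        };

        // The "method" under test just returns a constant value.
        let result = suite
            .benchmark_method(
                "Dummy".to_string(),
                BenchmarkCategory::FeatureImportance,
                || Ok(42_i32),
                test_config,
            )
            .unwrap();

        assert_eq!(suite.results.len(), 1);
        assert!(result.timing_stats.throughput >= 0.0);
        assert!(result.memory_stats.is_none());
    }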
}