Skip to main content

torsh_metrics/
lib.rs

//! Comprehensive evaluation metrics for ToRSh.
//!
//! This crate provides PyTorch-compatible metrics for model evaluation,
//! built on top of SciRS2's comprehensive metrics library.
5
6pub mod advanced_ml;
7pub mod classification;
8pub mod clustering;
9pub mod deep_learning;
10pub mod explainability;
11pub mod fairness;
12pub mod gpu;
13pub mod memory_efficient;
14pub mod mlflow;
15pub mod model_selection;
16pub mod parallel;
17pub mod ranking;
18pub mod regression;
19pub mod regression_diagnostics;
20pub mod reporting;
21pub mod robustness;
22pub mod sklearn_compat;
23pub mod statistical_tests;
24pub mod statistics;
25pub mod streaming;
26pub mod tensorboard;
27pub mod time_series;
28pub mod uncertainty;
29pub mod utils;
30pub mod visualization;
31pub mod wandb;
32
33// Re-export high-performance vectorized metrics for convenience
34pub use deep_learning::{
35    BleuScore, DeepLearningMetrics, RougeMetrics, RougeScore, RougeType, SimilarityType,
36    VectorizedFidScore, VectorizedInceptionScore, VectorizedPerplexity,
37    VectorizedSemanticSimilarity,
38};
39
40// Re-export classification metrics
41pub use classification::{ConfusionMatrix, MultiClassMetrics, ThresholdMetrics};
42
43// Re-export ranking/IR metrics
44pub use ranking::IRMetrics;
45
46// Re-export uncertainty quantification metrics
47pub use uncertainty::{
48    BayesianUncertainty, CalibrationMetrics, EnsembleUncertainty, MCDropoutUncertainty,
49    UncertaintyDecomposition,
50};
51
52// Re-export fairness metrics
53pub use fairness::FairnessMetrics;
54
55// Re-export statistical metrics
56pub use statistics::{BootstrapResult, CrossValidationResult, HypothesisTestResult};
57
58// Re-export GPU-accelerated metrics
59pub use gpu::{GpuAccuracy, GpuBatchMetrics, GpuConfusionMatrix};
60
61// Re-export parallel metrics
62pub use parallel::{ParallelAccuracy, ParallelConfusionMatrix, ParallelMetricCollection};
63
64// Re-export reporting utilities
65pub use reporting::{ComparisonReport, MetricReport, ReportBuilder, ReportFormat};
66
67// Re-export memory-efficient metrics
68pub use memory_efficient::{
69    ChunkedEvaluator, MemoryEfficientAccuracy, MemoryEfficientMAE, MemoryEfficientMSE,
70    OnlineConfusionMatrix, StreamingMetric,
71};
72
73// Re-export TensorBoard integration
74pub use tensorboard::{MetricLogger as TensorBoardLogger, TensorBoardWriter};
75
76// Re-export MLflow integration
77pub use mlflow::{ExperimentTracker, MLflowClient, MLflowRun};
78
79// Re-export visualization utilities
80pub use visualization::{
81    CalibrationCurvePlot, ConfusionMatrixPlot, ExportFormat, FeatureImportancePlot,
82    InteractiveDashboard, LatexReportBuilder, LearningCurvePlot, MetricComparisonPlot, PRCurvePlot,
83    ROCCurvePlot, VisualizationAggregator,
84};
85
86// Re-export advanced ML metrics
87pub use advanced_ml::{
88    ContinualLearningMetrics, DomainAdaptationMetrics, FewShotMetrics, MetaLearningMetrics,
89};
90
91// Re-export scikit-learn compatibility layer
92pub use sklearn_compat::{
93    SklearnAccuracy, SklearnF1Score, SklearnMeanAbsoluteError, SklearnMeanSquaredError,
94    SklearnMetric, SklearnPrecision, SklearnR2Score, SklearnRecall,
95};
96
97// Re-export Weights & Biases integration
98pub use wandb::{LogEntry, WandbClient};
99
100// Re-export model selection metrics
101pub use model_selection::{
102    AICc, CVModelComparison, CVModelSelection, CVScoreType, ModelComparisonReport,
103    MultiModelComparison, AIC, BIC, HQIC,
104};
105
106// Re-export statistical tests
107pub use statistical_tests::{
108    FiveByTwoCVTest, FriedmanTest, KruskalWallisTest, MannWhitneyTest, McNemarTest, NemenyiTest,
109    PairedTTest, WilcoxonTest,
110};
111
112// Re-export time series metrics
113pub use time_series::{
114    dtw_distance, error_autocorrelation, mape, mase, mean_directional_accuracy, msis, smape,
115    theil_u, tracking_signal,
116};
117
118// Re-export regression diagnostics
119pub use regression_diagnostics::{
120    breusch_pagan_test, calculate_leverage, condition_number, cooks_distance, dffits,
121    durbin_watson, variance_inflation_factor, RegressionDiagnosticReport, ResidualDiagnostics,
122};
123
124// Re-export explainability metrics
125pub use explainability::{
126    attribution_agreement, counterfactual_validity, explanation_completeness,
127    explanation_faithfulness, feature_importance_stability, feature_monotonicity,
128    interaction_strength, ExplainabilityMetrics,
129};
130
131// Re-export robustness metrics
132pub use robustness::{
133    adversarial_accuracy, attack_success_rate, certified_robustness_radius, confidence_stability,
134    corruption_robustness, gradient_stability, noise_sensitivity, ood_detection_score,
135    robustness_accuracy_tradeoff, RobustnessReport,
136};
137
138use torsh_tensor::Tensor;
139
140/// Base trait for all metrics
141pub trait Metric {
142    /// Compute the metric
143    fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f64;
144
145    /// Reset internal state (for stateful metrics)
146    fn reset(&mut self) {}
147
148    /// Update internal state with new batch
149    fn update(&mut self, _predictions: &Tensor, _targets: &Tensor) {}
150
151    /// Get the name of the metric
152    fn name(&self) -> &str;
153}
154
155/// Metric collection for evaluating multiple metrics at once
156pub struct MetricCollection {
157    metrics: Vec<Box<dyn Metric>>,
158    results: Vec<(String, f64)>,
159}
160
161impl MetricCollection {
162    /// Create a new metric collection
163    pub fn new() -> Self {
164        Self {
165            metrics: Vec::new(),
166            results: Vec::new(),
167        }
168    }
169
170    /// Add a metric to the collection
171    pub fn add<M: Metric + 'static>(mut self, metric: M) -> Self {
172        self.metrics.push(Box::new(metric));
173        self
174    }
175
176    /// Compute all metrics
177    pub fn compute(&mut self, predictions: &Tensor, targets: &Tensor) -> Vec<(String, f64)> {
178        self.results.clear();
179
180        for metric in &self.metrics {
181            let name = metric.name().to_string();
182            let value = metric.compute(predictions, targets);
183            self.results.push((name, value));
184        }
185
186        self.results.clone()
187    }
188
189    /// Reset all metrics
190    pub fn reset(&mut self) {
191        for metric in &mut self.metrics {
192            metric.reset();
193        }
194        self.results.clear();
195    }
196
197    /// Get results as a formatted string
198    pub fn format_results(&self) -> String {
199        self.results
200            .iter()
201            .map(|(name, value)| format!("{}: {:.4}", name, value))
202            .collect::<Vec<_>>()
203            .join(", ")
204    }
205}
206
207impl Default for MetricCollection {
208    fn default() -> Self {
209        Self::new()
210    }
211}