1pub mod advanced_ml;
7pub mod classification;
8pub mod clustering;
9pub mod deep_learning;
10pub mod explainability;
11pub mod fairness;
12pub mod gpu;
13pub mod memory_efficient;
14pub mod mlflow;
15pub mod model_selection;
16pub mod parallel;
17pub mod ranking;
18pub mod regression;
19pub mod regression_diagnostics;
20pub mod reporting;
21pub mod robustness;
22pub mod sklearn_compat;
23pub mod statistical_tests;
24pub mod statistics;
25pub mod streaming;
26pub mod tensorboard;
27pub mod time_series;
28pub mod uncertainty;
29pub mod utils;
30pub mod visualization;
31pub mod wandb;
32
33pub use deep_learning::{
35 BleuScore, DeepLearningMetrics, RougeMetrics, RougeScore, RougeType, SimilarityType,
36 VectorizedFidScore, VectorizedInceptionScore, VectorizedPerplexity,
37 VectorizedSemanticSimilarity,
38};
39
40pub use classification::{ConfusionMatrix, MultiClassMetrics, ThresholdMetrics};
42
43pub use ranking::IRMetrics;
45
46pub use uncertainty::{
48 BayesianUncertainty, CalibrationMetrics, EnsembleUncertainty, MCDropoutUncertainty,
49 UncertaintyDecomposition,
50};
51
52pub use fairness::FairnessMetrics;
54
55pub use statistics::{BootstrapResult, CrossValidationResult, HypothesisTestResult};
57
58pub use gpu::{GpuAccuracy, GpuBatchMetrics, GpuConfusionMatrix};
60
61pub use parallel::{ParallelAccuracy, ParallelConfusionMatrix, ParallelMetricCollection};
63
64pub use reporting::{ComparisonReport, MetricReport, ReportBuilder, ReportFormat};
66
67pub use memory_efficient::{
69 ChunkedEvaluator, MemoryEfficientAccuracy, MemoryEfficientMAE, MemoryEfficientMSE,
70 OnlineConfusionMatrix, StreamingMetric,
71};
72
73pub use tensorboard::{MetricLogger as TensorBoardLogger, TensorBoardWriter};
75
76pub use mlflow::{ExperimentTracker, MLflowClient, MLflowRun};
78
79pub use visualization::{
81 CalibrationCurvePlot, ConfusionMatrixPlot, ExportFormat, FeatureImportancePlot,
82 InteractiveDashboard, LatexReportBuilder, LearningCurvePlot, MetricComparisonPlot, PRCurvePlot,
83 ROCCurvePlot, VisualizationAggregator,
84};
85
86pub use advanced_ml::{
88 ContinualLearningMetrics, DomainAdaptationMetrics, FewShotMetrics, MetaLearningMetrics,
89};
90
91pub use sklearn_compat::{
93 SklearnAccuracy, SklearnF1Score, SklearnMeanAbsoluteError, SklearnMeanSquaredError,
94 SklearnMetric, SklearnPrecision, SklearnR2Score, SklearnRecall,
95};
96
97pub use wandb::{LogEntry, WandbClient};
99
100pub use model_selection::{
102 AICc, CVModelComparison, CVModelSelection, CVScoreType, ModelComparisonReport,
103 MultiModelComparison, AIC, BIC, HQIC,
104};
105
106pub use statistical_tests::{
108 FiveByTwoCVTest, FriedmanTest, KruskalWallisTest, MannWhitneyTest, McNemarTest, NemenyiTest,
109 PairedTTest, WilcoxonTest,
110};
111
112pub use time_series::{
114 dtw_distance, error_autocorrelation, mape, mase, mean_directional_accuracy, msis, smape,
115 theil_u, tracking_signal,
116};
117
118pub use regression_diagnostics::{
120 breusch_pagan_test, calculate_leverage, condition_number, cooks_distance, dffits,
121 durbin_watson, variance_inflation_factor, RegressionDiagnosticReport, ResidualDiagnostics,
122};
123
124pub use explainability::{
126 attribution_agreement, counterfactual_validity, explanation_completeness,
127 explanation_faithfulness, feature_importance_stability, feature_monotonicity,
128 interaction_strength, ExplainabilityMetrics,
129};
130
131pub use robustness::{
133 adversarial_accuracy, attack_success_rate, certified_robustness_radius, confidence_stability,
134 corruption_robustness, gradient_stability, noise_sensitivity, ood_detection_score,
135 robustness_accuracy_tradeoff, RobustnessReport,
136};
137
138use torsh_tensor::Tensor;
139
140pub trait Metric {
142 fn compute(&self, predictions: &Tensor, targets: &Tensor) -> f64;
144
145 fn reset(&mut self) {}
147
148 fn update(&mut self, _predictions: &Tensor, _targets: &Tensor) {}
150
151 fn name(&self) -> &str;
153}
154
155pub struct MetricCollection {
157 metrics: Vec<Box<dyn Metric>>,
158 results: Vec<(String, f64)>,
159}
160
161impl MetricCollection {
162 pub fn new() -> Self {
164 Self {
165 metrics: Vec::new(),
166 results: Vec::new(),
167 }
168 }
169
170 pub fn add<M: Metric + 'static>(mut self, metric: M) -> Self {
172 self.metrics.push(Box::new(metric));
173 self
174 }
175
176 pub fn compute(&mut self, predictions: &Tensor, targets: &Tensor) -> Vec<(String, f64)> {
178 self.results.clear();
179
180 for metric in &self.metrics {
181 let name = metric.name().to_string();
182 let value = metric.compute(predictions, targets);
183 self.results.push((name, value));
184 }
185
186 self.results.clone()
187 }
188
189 pub fn reset(&mut self) {
191 for metric in &mut self.metrics {
192 metric.reset();
193 }
194 self.results.clear();
195 }
196
197 pub fn format_results(&self) -> String {
199 self.results
200 .iter()
201 .map(|(name, value)| format!("{}: {:.4}", name, value))
202 .collect::<Vec<_>>()
203 .join(", ")
204 }
205}
206
207impl Default for MetricCollection {
208 fn default() -> Self {
209 Self::new()
210 }
211}