1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlClassificationMetric, MlClusteringMetric, MlGenerationMetric, MlMetricAggregation,
10 MlMetricDirection, MlMetricError, MlMetricKind, MlMetricName, MlMetricValue,
11 MlRankingMetric, MlRegressionMetric,
12 };
13}
14
15#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
16pub struct MlMetricName(String);
17
18impl MlMetricName {
19 pub fn new(value: impl AsRef<str>) -> Result<Self, MlMetricError> {
20 non_empty_text(value).map(Self)
21 }
22
23 pub fn as_str(&self) -> &str {
24 &self.0
25 }
26}
27
28impl AsRef<str> for MlMetricName {
29 fn as_ref(&self) -> &str {
30 self.as_str()
31 }
32}
33
34impl fmt::Display for MlMetricName {
35 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
36 formatter.write_str(self.as_str())
37 }
38}
39
40impl FromStr for MlMetricName {
41 type Err = MlMetricError;
42
43 fn from_str(value: &str) -> Result<Self, Self::Err> {
44 Self::new(value)
45 }
46}
47
48impl TryFrom<&str> for MlMetricName {
49 type Error = MlMetricError;
50
51 fn try_from(value: &str) -> Result<Self, Self::Error> {
52 Self::new(value)
53 }
54}
55
56#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
57pub struct MlMetricValue(f64);
58
59impl MlMetricValue {
60 pub fn new(value: f64) -> Result<Self, MlMetricError> {
61 if value.is_finite() {
62 Ok(Self(value))
63 } else {
64 Err(MlMetricError::NonFinite)
65 }
66 }
67
68 pub const fn value(self) -> f64 {
69 self.0
70 }
71}
72
/// Declares a public, fieldless `Copy` enum in which every variant is paired with one
/// canonical string label.
///
/// For `metric_enum!(Name { Variant => "label", ... })` the expansion provides:
/// - the enum itself, deriving `Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd`
///   (the derived ordering follows declaration order);
/// - `Name::as_str`, a `const fn` returning the variant's label;
/// - `Display`, which writes that label verbatim;
/// - `FromStr`, which runs the input through `normalized_label` and yields the variant whose
///   label equals the result, `MlMetricError::UnknownLabel` when none does, or the error
///   reported by `normalized_label` itself.
///
/// The invocation site must have `fmt`, `FromStr`, `MlMetricError`, and `normalized_label`
/// in scope.
macro_rules! metric_enum {
    ($enum_name:ident { $($member:ident => $text:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $enum_name {
            $($member),+
        }

        impl $enum_name {
            pub const fn as_str(self) -> &'static str {
                match self {
                    $( $enum_name::$member => $text, )+
                }
            }
        }

        impl fmt::Display for $enum_name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                let label: &'static str = (*self).as_str();
                formatter.write_str(label)
            }
        }

        impl FromStr for $enum_name {
            type Err = MlMetricError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                let wanted = normalized_label(value)?;
                // `as_str` is the single source of truth for labels; scan the variants in
                // declaration order and take the first whose label matches.
                [$( $enum_name::$member ),+]
                    .iter()
                    .copied()
                    .find(|candidate| candidate.as_str() == wanted.as_str())
                    .ok_or(MlMetricError::UnknownLabel)
            }
        }
    };
}
106
// Coarse category labels for a metric. `metric_enum!` generates the `MlMetricKind` enum
// plus `as_str`/`Display` (yielding the string on the right of each `=>`) and a `FromStr`
// that matches the normalized input against those same strings.
metric_enum!(MlMetricKind {
    Classification => "classification",
    Regression => "regression",
    Ranking => "ranking",
    Clustering => "clustering",
    Forecasting => "forecasting",
    Generation => "generation",
    Retrieval => "retrieval",
    Calibration => "calibration",
    Fairness => "fairness",
    Performance => "performance",
    Resource => "resource",
    Other => "other",
});
121
// Vocabulary for which way a metric should move to count as an improvement. This is the
// return type of the `direction()` helpers on the task-specific metric enums; the helpers
// visible in this file only ever produce `HigherIsBetter` or `LowerIsBetter`, while
// `TargetIsBest` and `Unknown` are available to callers but not produced here.
metric_enum!(MlMetricDirection {
    HigherIsBetter => "higher-is-better",
    LowerIsBetter => "lower-is-better",
    TargetIsBest => "target-is-best",
    Unknown => "unknown",
});
128
// Labels naming how a metric value was aggregated. Only the label strings are defined
// here; NOTE(review): the precise meaning of each strategy (e.g. `Macro` vs `Micro`) is
// not specified anywhere in this file — confirm against the consumers of this enum.
metric_enum!(MlMetricAggregation {
    Mean => "mean",
    Median => "median",
    Min => "min",
    Max => "max",
    Sum => "sum",
    WeightedMean => "weighted-mean",
    Macro => "macro",
    Micro => "micro",
    Samples => "samples",
    None => "none",
});
141
// Identifiers for classification metrics, each paired with its canonical kebab-case label
// (used by the generated `as_str`, `Display`, and `FromStr`).
metric_enum!(MlClassificationMetric {
    Accuracy => "accuracy",
    Precision => "precision",
    Recall => "recall",
    F1 => "f1",
    RocAuc => "roc-auc",
    PrAuc => "pr-auc",
    LogLoss => "log-loss",
    MatthewsCorrelationCoefficient => "matthews-correlation-coefficient",
    BalancedAccuracy => "balanced-accuracy",
});
153
154impl MlClassificationMetric {
155 pub const fn direction(self) -> MlMetricDirection {
156 match self {
157 Self::LogLoss => MlMetricDirection::LowerIsBetter,
158 Self::Accuracy
159 | Self::Precision
160 | Self::Recall
161 | Self::F1
162 | Self::RocAuc
163 | Self::PrAuc
164 | Self::MatthewsCorrelationCoefficient
165 | Self::BalancedAccuracy => MlMetricDirection::HigherIsBetter,
166 }
167 }
168}
169
// Identifiers for regression metrics, each paired with its canonical kebab-case label
// (used by the generated `as_str`, `Display`, and `FromStr`).
metric_enum!(MlRegressionMetric {
    Mae => "mae",
    Mse => "mse",
    Rmse => "rmse",
    R2 => "r2",
    Mape => "mape",
    Smape => "smape",
    MedianAbsoluteError => "median-absolute-error",
});
179
180impl MlRegressionMetric {
181 pub const fn direction(self) -> MlMetricDirection {
182 match self {
183 Self::R2 => MlMetricDirection::HigherIsBetter,
184 Self::Mae
185 | Self::Mse
186 | Self::Rmse
187 | Self::Mape
188 | Self::Smape
189 | Self::MedianAbsoluteError => MlMetricDirection::LowerIsBetter,
190 }
191 }
192}
193
// Identifiers for ranking metrics, each paired with its canonical kebab-case label
// (used by the generated `as_str`, `Display`, and `FromStr`).
metric_enum!(MlRankingMetric {
    Ndcg => "ndcg",
    Map => "map",
    Mrr => "mrr",
    HitRate => "hit-rate",
    RecallAtK => "recall-at-k",
    PrecisionAtK => "precision-at-k",
});
202
203impl MlRankingMetric {
204 pub const fn direction(self) -> MlMetricDirection {
205 MlMetricDirection::HigherIsBetter
206 }
207}
208
// Identifiers for clustering metrics, each paired with its canonical kebab-case label
// (used by the generated `as_str`, `Display`, and `FromStr`).
metric_enum!(MlClusteringMetric {
    Silhouette => "silhouette",
    AdjustedRandIndex => "adjusted-rand-index",
    NormalizedMutualInfo => "normalized-mutual-info",
    DaviesBouldin => "davies-bouldin",
});
215
216impl MlClusteringMetric {
217 pub const fn direction(self) -> MlMetricDirection {
218 match self {
219 Self::DaviesBouldin => MlMetricDirection::LowerIsBetter,
220 Self::Silhouette | Self::AdjustedRandIndex | Self::NormalizedMutualInfo => {
221 MlMetricDirection::HigherIsBetter
222 },
223 }
224 }
225}
226
// Identifiers for generation metrics, each paired with its canonical kebab-case label
// (used by the generated `as_str`, `Display`, and `FromStr`).
metric_enum!(MlGenerationMetric {
    Bleu => "bleu",
    Rouge => "rouge",
    Meteor => "meteor",
    BertScore => "bert-score",
    ExactMatch => "exact-match",
    Perplexity => "perplexity",
});
235
236impl MlGenerationMetric {
237 pub const fn direction(self) -> MlMetricDirection {
238 match self {
239 Self::Perplexity => MlMetricDirection::LowerIsBetter,
240 Self::Bleu | Self::Rouge | Self::Meteor | Self::BertScore | Self::ExactMatch => {
241 MlMetricDirection::HigherIsBetter
242 },
243 }
244 }
245}
246
/// Reasons a piece of metric metadata can be rejected by the constructors and parsers in
/// this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlMetricError {
    /// The supplied text was empty once surrounding whitespace was trimmed.
    Empty,
    /// The supplied number was NaN or infinite.
    NonFinite,
    /// The supplied text did not match any known label of the requested enum.
    UnknownLabel,
}

impl fmt::Display for MlMetricError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first so there is a single write call.
        let message = match *self {
            Self::Empty => "ML metric metadata text cannot be empty",
            Self::NonFinite => "ML metric value must be finite",
            Self::UnknownLabel => "unknown ML metric metadata label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl Error for MlMetricError {}
265
266fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlMetricError> {
267 let trimmed = value.as_ref().trim();
268 if trimmed.is_empty() {
269 Err(MlMetricError::Empty)
270 } else {
271 Ok(trimmed.to_string())
272 }
273}
274
275fn normalized_label(value: &str) -> Result<String, MlMetricError> {
276 let trimmed = value.trim();
277 if trimmed.is_empty() {
278 Err(MlMetricError::Empty)
279 } else {
280 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::{
        MlClassificationMetric, MlMetricDirection, MlMetricError, MlMetricName, MlMetricValue,
        MlRankingMetric, MlRegressionMetric,
    };

    /// Happy path plus the most basic rejections for the two validated newtypes.
    #[test]
    fn validates_metric_names_and_values() -> Result<(), MlMetricError> {
        let name = MlMetricName::new(" accuracy ")?;
        let value = MlMetricValue::new(0.93)?;

        assert_eq!(name.as_str(), "accuracy");
        assert_eq!(value.value(), 0.93);
        assert_eq!(MlMetricName::new(" "), Err(MlMetricError::Empty));
        assert_eq!(MlMetricValue::new(f64::NAN), Err(MlMetricError::NonFinite));
        Ok(())
    }

    /// Space-separated input parses to the right variant, and `direction()` is spot-checked.
    #[test]
    fn displays_parses_and_labels_metric_directions() -> Result<(), MlMetricError> {
        assert_eq!(
            "roc auc".parse::<MlClassificationMetric>()?,
            MlClassificationMetric::RocAuc
        );
        assert_eq!(
            "precision at k".parse::<MlRankingMetric>()?,
            MlRankingMetric::PrecisionAtK
        );
        assert_eq!(
            "rmse".parse::<MlRegressionMetric>()?,
            MlRegressionMetric::Rmse
        );
        assert_eq!(
            MlClassificationMetric::Accuracy.direction(),
            MlMetricDirection::HigherIsBetter
        );
        assert_eq!(
            MlClassificationMetric::LogLoss.direction(),
            MlMetricDirection::LowerIsBetter
        );
        assert_eq!(
            MlRegressionMetric::R2.direction(),
            MlMetricDirection::HigherIsBetter
        );
        assert_eq!(
            MlRegressionMetric::Rmse.direction(),
            MlMetricDirection::LowerIsBetter
        );
        Ok(())
    }

    /// Covers the error paths and label normalization that the tests above leave untouched:
    /// infinite values, empty and unknown labels, underscore/upper-case input, and the
    /// `as_str` -> `parse` round trip.
    #[test]
    fn rejects_invalid_input_and_round_trips_labels() {
        assert_eq!(
            MlMetricValue::new(f64::INFINITY),
            Err(MlMetricError::NonFinite)
        );
        assert_eq!(
            MlMetricValue::new(f64::NEG_INFINITY),
            Err(MlMetricError::NonFinite)
        );
        assert_eq!("".parse::<MlMetricName>(), Err(MlMetricError::Empty));

        assert_eq!(
            " ".parse::<MlClassificationMetric>(),
            Err(MlMetricError::Empty)
        );
        assert_eq!(
            "not-a-metric".parse::<MlClassificationMetric>(),
            Err(MlMetricError::UnknownLabel)
        );
        assert_eq!(
            "ROC_AUC".parse::<MlClassificationMetric>(),
            Ok(MlClassificationMetric::RocAuc)
        );

        assert_eq!(MlClassificationMetric::RocAuc.to_string(), "roc-auc");
        assert_eq!(
            MlRegressionMetric::MedianAbsoluteError.to_string(),
            "median-absolute-error"
        );
        assert_eq!(
            MlRankingMetric::PrecisionAtK
                .as_str()
                .parse::<MlRankingMetric>(),
            Ok(MlRankingMetric::PrecisionAtK)
        );
    }
}