Skip to main content

use_ml_evaluation/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::{error::Error, num::NonZeroUsize};
6
7pub mod prelude {
8    pub use crate::{
9        MlBenchmarkName, MlConfusionMatrixShape, MlEvalSliceKind, MlEvalSliceName,
10        MlEvaluationError, MlEvaluationKind, MlEvaluationRunId, MlEvaluationStatus,
11        MlEvaluationTarget, MlThreshold, MlValidationStrategy,
12    };
13}
14
// Generates a public newtype `$name` wrapping a `String` that is guaranteed to
// be non-empty once surrounding whitespace has been trimmed.
//
// The generated type offers `new` and `as_str` plus `AsRef<str>`, `Display`,
// `FromStr`, and `TryFrom<&str>`. Every fallible conversion funnels through
// `new`, so the validation rule lives in exactly one place.
macro_rules! evaluation_text_newtype {
    ($name:ident) => {
        /// Validated text value: never empty, stored without leading or
        /// trailing whitespace.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims `value` and wraps the result.
            ///
            /// # Errors
            ///
            /// Returns [`MlEvaluationError::Empty`] when `value` is empty or
            /// contains only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlEvaluationError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` (unlike `write_str`) applies the caller's width,
                // alignment, fill, and precision flags, so `{:>12}` and
                // friends work as they do for a plain `&str`.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlEvaluationError;

            /// Parses with the same validation as `new`.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlEvaluationError;

            /// Converts with the same validation as `new`.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
59
// Generates a public field-less enum `$name` whose variants each map to a
// fixed kebab-case label.
//
// The generated type offers a `const fn as_str`, a `Display` impl that prints
// the label, and a `FromStr` impl that accepts the label after normalization
// by `normalized_label` (so spelling variants such as different case or
// `_`/space separators are tolerated).
macro_rules! evaluation_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of labelled values; see `as_str` for the canonical label.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Returns the canonical label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` (unlike `write_str`) applies the caller's width,
                // alignment, fill, and precision flags, so labels can be
                // lined up in tabular output with `{:<12}` and friends.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlEvaluationError;

            /// Parses a label after normalization.
            ///
            /// Fails with whatever error `normalized_label` reports, or with
            /// `MlEvaluationError::UnknownLabel` when the normalized text
            /// matches no variant.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    // Catch-all is required: the scrutinee is arbitrary text.
                    _ => Err(MlEvaluationError::UnknownLabel),
                }
            }
        }
    };
}
93
94evaluation_text_newtype!(MlEvaluationRunId);
95evaluation_text_newtype!(MlEvalSliceName);
96evaluation_text_newtype!(MlBenchmarkName);
97
98#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
99pub struct MlThreshold(f64);
100
101impl MlThreshold {
102    pub fn new(value: f64) -> Result<Self, MlEvaluationError> {
103        if !value.is_finite() {
104            return Err(MlEvaluationError::NonFinite);
105        }
106        if !(0.0..=1.0).contains(&value) {
107            return Err(MlEvaluationError::OutOfRange);
108        }
109        Ok(Self(value))
110    }
111
112    pub const fn value(self) -> f64 {
113        self.0
114    }
115}
116
117#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
118pub struct MlConfusionMatrixShape {
119    rows: NonZeroUsize,
120    columns: NonZeroUsize,
121}
122
123impl MlConfusionMatrixShape {
124    pub fn new(rows: usize, columns: usize) -> Result<Self, MlEvaluationError> {
125        Ok(Self {
126            rows: NonZeroUsize::new(rows).ok_or(MlEvaluationError::Zero)?,
127            columns: NonZeroUsize::new(columns).ok_or(MlEvaluationError::Zero)?,
128        })
129    }
130
131    pub const fn rows(self) -> usize {
132        self.rows.get()
133    }
134
135    pub const fn columns(self) -> usize {
136        self.columns.get()
137    }
138
139    pub const fn is_square(self) -> bool {
140        self.rows.get() == self.columns.get()
141    }
142}
143
144evaluation_enum!(MlEvaluationKind {
145    Offline => "offline",
146    Online => "online",
147    Shadow => "shadow",
148    ABTest => "a-b-test",
149    Backtest => "backtest",
150    CrossValidation => "cross-validation",
151    Holdout => "holdout",
152    Benchmark => "benchmark",
153    HumanEval => "human-eval",
154    Other => "other",
155});
156
157evaluation_enum!(MlValidationStrategy {
158    Holdout => "holdout",
159    KFold => "k-fold",
160    StratifiedKFold => "stratified-k-fold",
161    TimeSeriesSplit => "time-series-split",
162    LeaveOneOut => "leave-one-out",
163    Bootstrap => "bootstrap",
164    Custom => "custom",
165});
166
167evaluation_enum!(MlEvaluationStatus {
168    Pending => "pending",
169    Running => "running",
170    Succeeded => "succeeded",
171    Failed => "failed",
172    Cancelled => "cancelled",
173    Inconclusive => "inconclusive",
174});
175
176evaluation_enum!(MlEvaluationTarget {
177    Model => "model",
178    Pipeline => "pipeline",
179    Dataset => "dataset",
180    Feature => "feature",
181    Label => "label",
182    Artifact => "artifact",
183    TrainingRun => "training-run",
184    Other => "other",
185});
186
187evaluation_enum!(MlEvalSliceKind {
188    Global => "global",
189    Class => "class",
190    Segment => "segment",
191    Cohort => "cohort",
192    Geography => "geography",
193    TimeWindow => "time-window",
194    Device => "device",
195    Language => "language",
196    Custom => "custom",
197});
198
/// Reasons a value can be rejected by this crate's constructors and parsers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlEvaluationError {
    /// Text was empty, or contained only whitespace.
    Empty,
    /// A floating-point value was NaN or infinite.
    NonFinite,
    /// A threshold was outside `0.0..=1.0`.
    OutOfRange,
    /// A count that must be positive was zero.
    Zero,
    /// Text did not match any known label.
    UnknownLabel,
}

impl fmt::Display for MlEvaluationError {
    /// Writes the fixed human-readable message for this error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Select the message first so there is a single write site.
        let message = match *self {
            Self::Empty => "ML evaluation metadata text cannot be empty",
            Self::NonFinite => "ML evaluation value must be finite",
            Self::OutOfRange => "ML evaluation threshold must be in 0.0..=1.0",
            Self::Zero => "ML evaluation count must be positive",
            Self::UnknownLabel => "unknown ML evaluation metadata label",
        };
        formatter.write_str(message)
    }
}

impl Error for MlEvaluationError {}
221
222fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlEvaluationError> {
223    let trimmed = value.as_ref().trim();
224    if trimmed.is_empty() {
225        Err(MlEvaluationError::Empty)
226    } else {
227        Ok(trimmed.to_string())
228    }
229}
230
231fn normalized_label(value: &str) -> Result<String, MlEvaluationError> {
232    let trimmed = value.trim();
233    if trimmed.is_empty() {
234        Err(MlEvaluationError::Empty)
235    } else {
236        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::{
        MlConfusionMatrixShape, MlEvaluationError, MlEvaluationKind, MlEvaluationRunId,
        MlEvaluationStatus, MlThreshold, MlValidationStrategy,
    };

    /// Text newtypes trim their input, round-trip through `Display`/`FromStr`,
    /// and reject whitespace-only text.
    #[test]
    fn validates_evaluation_run_ids() -> Result<(), MlEvaluationError> {
        let run_id = MlEvaluationRunId::new(" eval-001 ")?;

        assert_eq!(run_id.as_str(), "eval-001");
        assert_eq!(run_id.to_string(), "eval-001");
        assert_eq!("eval-001".parse::<MlEvaluationRunId>()?, run_id);
        assert_eq!(MlEvaluationRunId::new("   "), Err(MlEvaluationError::Empty));
        Ok(())
    }

    /// Thresholds accept the inclusive bounds and reject out-of-range and
    /// non-finite values; shapes reject a zero in either dimension.
    #[test]
    fn validates_thresholds_and_confusion_matrix_shapes() -> Result<(), MlEvaluationError> {
        assert_eq!(MlThreshold::new(0.0)?.value(), 0.0);
        assert_eq!(MlThreshold::new(1.0)?.value(), 1.0);
        assert_eq!(MlThreshold::new(-0.1), Err(MlEvaluationError::OutOfRange));
        assert_eq!(MlThreshold::new(1.1), Err(MlEvaluationError::OutOfRange));
        assert_eq!(
            MlThreshold::new(f64::NAN),
            Err(MlEvaluationError::NonFinite)
        );
        assert_eq!(
            MlThreshold::new(f64::INFINITY),
            Err(MlEvaluationError::NonFinite)
        );

        let shape = MlConfusionMatrixShape::new(3, 3)?;
        assert_eq!(shape.rows(), 3);
        assert_eq!(shape.columns(), 3);
        assert!(shape.is_square());
        assert!(!MlConfusionMatrixShape::new(2, 3)?.is_square());
        assert_eq!(
            MlConfusionMatrixShape::new(0, 3),
            Err(MlEvaluationError::Zero)
        );
        assert_eq!(
            MlConfusionMatrixShape::new(3, 0),
            Err(MlEvaluationError::Zero)
        );
        Ok(())
    }

    /// Enums parse tolerant spellings, print their canonical label, and
    /// report empty or unknown text with the matching error.
    #[test]
    fn displays_and_parses_evaluation_enums() -> Result<(), MlEvaluationError> {
        assert_eq!(
            "a b test".parse::<MlEvaluationKind>()?,
            MlEvaluationKind::ABTest
        );
        assert_eq!(
            "stratified k fold".parse::<MlValidationStrategy>()?,
            MlValidationStrategy::StratifiedKFold
        );
        assert_eq!(
            "cancelled".parse::<MlEvaluationStatus>()?,
            MlEvaluationStatus::Cancelled
        );

        // Display side of the round trip, which the parsing checks above
        // do not exercise.
        assert_eq!(MlEvaluationKind::ABTest.to_string(), "a-b-test");
        assert_eq!(
            MlValidationStrategy::StratifiedKFold.to_string(),
            "stratified-k-fold"
        );
        assert_eq!(MlEvaluationStatus::Cancelled.as_str(), "cancelled");

        // Failure paths.
        assert_eq!(
            "unheard-of".parse::<MlEvaluationStatus>(),
            Err(MlEvaluationError::UnknownLabel)
        );
        assert_eq!(
            "   ".parse::<MlEvaluationKind>(),
            Err(MlEvaluationError::Empty)
        );
        Ok(())
    }
}