#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]

use core::{fmt, str::FromStr};
use std::{error::Error, num::NonZeroUsize};

7pub mod prelude {
8 pub use crate::{
9 MlBatchSize, MlCheckpointKind, MlEpochCount, MlHyperparameterName, MlHyperparameterValue,
10 MlLearningRate, MlLossKind, MlOptimizerKind, MlTrainingError, MlTrainingJobName,
11 MlTrainingPhase, MlTrainingRunId, MlTrainingStatus,
12 };
13}
14
/// Generates a public newtype `$name` around a trimmed, non-empty `String`.
///
/// The generated type offers `new` / `as_str` and implements `AsRef<str>`, `Display`,
/// `FromStr`, and `TryFrom<&str>`. Every fallible entry point funnels through
/// `non_empty_text`, so blank input is rejected with `MlTrainingError::Empty`.
macro_rules! training_text_newtype {
    ($name:ident) => {
        /// Text value that is guaranteed to be trimmed and non-empty.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Builds the value from `value`, dropping surrounding whitespace.
            ///
            /// # Errors
            ///
            /// Returns [`MlTrainingError::Empty`] when nothing is left after trimming.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlTrainingError> {
                non_empty_text(value).map(Self)
            }

            /// Returns the stored (already trimmed) text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` prints the same text as `write_str` for a plain `{}`, but also applies
                // the width, fill, alignment, and precision flags that `write_str` ignores.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlTrainingError;

            /// Same validation as `new`.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlTrainingError;

            /// Same validation as `new`.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}

/// Generates a public, field-less enum `$name` whose variants each carry a canonical label.
///
/// Invocation shape: `training_enum!(Name { Variant => "label", ... });`.
///
/// The generated type offers a `const` `as_str` and implements `Display`, `FromStr`, and
/// `TryFrom<&str>`. Parsing first runs the input through `normalized_label` and then
/// compares it with the labels exactly, so each `$label` must already be in normalized form.
macro_rules! training_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// Closed set of labelled variants; [`Self::as_str`] yields the canonical label.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Returns the canonical label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` prints the same text as `write_str` for a plain `{}`, but also applies
                // the width, fill, alignment, and precision flags that `write_str` ignores.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlTrainingError;

            /// Parses a label after normalization.
            ///
            /// Propagates the error from `normalized_label` for blank input and returns
            /// `MlTrainingError::UnknownLabel` when no variant matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlTrainingError::UnknownLabel),
                }
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlTrainingError;

            /// Same parsing rules as `FromStr`.
            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}

94training_text_newtype!(MlTrainingRunId);
95training_text_newtype!(MlTrainingJobName);
96training_text_newtype!(MlHyperparameterName);
97training_text_newtype!(MlHyperparameterValue);
98
99#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
100pub struct MlBatchSize(NonZeroUsize);
101
102impl MlBatchSize {
103 pub fn new(value: usize) -> Result<Self, MlTrainingError> {
104 NonZeroUsize::new(value)
105 .map(Self)
106 .ok_or(MlTrainingError::Zero)
107 }
108
109 pub const fn get(self) -> usize {
110 self.0.get()
111 }
112}
113
114#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
115pub struct MlEpochCount(NonZeroUsize);
116
117impl MlEpochCount {
118 pub fn new(value: usize) -> Result<Self, MlTrainingError> {
119 NonZeroUsize::new(value)
120 .map(Self)
121 .ok_or(MlTrainingError::Zero)
122 }
123
124 pub const fn get(self) -> usize {
125 self.0.get()
126 }
127}
128
129#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
130pub struct MlLearningRate(f64);
131
132impl MlLearningRate {
133 pub fn new(value: f64) -> Result<Self, MlTrainingError> {
134 if !value.is_finite() {
135 return Err(MlTrainingError::NonFinite);
136 }
137 if value <= 0.0 {
138 return Err(MlTrainingError::NonPositive);
139 }
140 Ok(Self(value))
141 }
142
143 pub const fn value(self) -> f64 {
144 self.0
145 }
146}
147
148training_enum!(MlTrainingStatus {
149 Queued => "queued",
150 Running => "running",
151 Succeeded => "succeeded",
152 Failed => "failed",
153 Cancelled => "cancelled",
154 TimedOut => "timed-out",
155 Paused => "paused",
156 Unknown => "unknown",
157});
158
159training_enum!(MlTrainingPhase {
160 PrepareData => "prepare-data",
161 Initialize => "initialize",
162 Train => "train",
163 Validate => "validate",
164 Tune => "tune",
165 Checkpoint => "checkpoint",
166 Evaluate => "evaluate",
167 Export => "export",
168 Complete => "complete",
169});
170
171training_enum!(MlOptimizerKind {
172 Sgd => "sgd",
173 Momentum => "momentum",
174 Adam => "adam",
175 AdamW => "adamw",
176 RmsProp => "rmsprop",
177 Adagrad => "adagrad",
178 Adadelta => "adadelta",
179 Lbfgs => "lbfgs",
180 Custom => "custom",
181});
182
183training_enum!(MlLossKind {
184 CrossEntropy => "cross-entropy",
185 BinaryCrossEntropy => "binary-cross-entropy",
186 MeanSquaredError => "mean-squared-error",
187 MeanAbsoluteError => "mean-absolute-error",
188 Huber => "huber",
189 Hinge => "hinge",
190 Triplet => "triplet",
191 Contrastive => "contrastive",
192 Custom => "custom",
193});
194
195training_enum!(MlCheckpointKind {
196 Best => "best",
197 Latest => "latest",
198 Epoch => "epoch",
199 Step => "step",
200 Manual => "manual",
201 Final => "final",
202});
203
/// Validation failures reported by the constructors and parsers in this file.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MlTrainingError {
    /// Text input was empty or contained only whitespace.
    Empty,
    /// A count that must be non-zero was `0`.
    Zero,
    /// A floating-point value was NaN or infinite.
    NonFinite,
    /// A floating-point value was zero or negative.
    NonPositive,
    /// A label did not match any known enum variant.
    UnknownLabel,
}

impl fmt::Display for MlTrainingError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "ML training metadata text cannot be empty",
            Self::Zero => "ML training count must be positive",
            Self::NonFinite => "ML training value must be finite",
            Self::NonPositive => "ML training value must be positive",
            Self::UnknownLabel => "unknown ML training metadata label",
        };
        // `pad` prints the same text as `write_str` for a plain `{}` and additionally
        // honours width / alignment flags.
        formatter.pad(message)
    }
}

impl Error for MlTrainingError {}

227fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlTrainingError> {
228 let trimmed = value.as_ref().trim();
229 if trimmed.is_empty() {
230 Err(MlTrainingError::Empty)
231 } else {
232 Ok(trimmed.to_string())
233 }
234}
235
236fn normalized_label(value: &str) -> Result<String, MlTrainingError> {
237 let trimmed = value.trim();
238 if trimmed.is_empty() {
239 Err(MlTrainingError::Empty)
240 } else {
241 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
242 }
243}
244
// Unit tests for the validated newtypes, the label enums, and the error messages.
#[cfg(test)]
mod tests {
    use super::{
        MlBatchSize, MlCheckpointKind, MlEpochCount, MlLearningRate, MlLossKind, MlOptimizerKind,
        MlTrainingError, MlTrainingRunId, MlTrainingStatus,
    };

    #[test]
    fn validates_training_ids() -> Result<(), MlTrainingError> {
        let run_id = MlTrainingRunId::new(" run-001 ")?;

        assert_eq!(run_id.as_str(), "run-001");
        assert_eq!("run-001".parse::<MlTrainingRunId>()?, run_id);
        assert_eq!(run_id.to_string(), "run-001");
        Ok(())
    }

    #[test]
    fn rejects_blank_training_ids() {
        assert_eq!(MlTrainingRunId::new(""), Err(MlTrainingError::Empty));
        assert_eq!(MlTrainingRunId::new(" \t\n"), Err(MlTrainingError::Empty));
        assert_eq!(
            "   ".parse::<MlTrainingRunId>(),
            Err(MlTrainingError::Empty)
        );
    }

    #[test]
    fn validates_positive_counts_and_learning_rates() -> Result<(), MlTrainingError> {
        assert_eq!(MlBatchSize::new(32)?.get(), 32);
        assert_eq!(MlEpochCount::new(10)?.get(), 10);
        assert_eq!(MlLearningRate::new(0.001)?.value(), 0.001);
        assert_eq!(MlBatchSize::new(0), Err(MlTrainingError::Zero));
        assert_eq!(MlEpochCount::new(0), Err(MlTrainingError::Zero));
        assert_eq!(MlLearningRate::new(0.0), Err(MlTrainingError::NonPositive));
        assert_eq!(
            MlLearningRate::new(f64::NAN),
            Err(MlTrainingError::NonFinite)
        );
        Ok(())
    }

    #[test]
    fn rejects_negative_and_infinite_learning_rates() {
        assert_eq!(MlLearningRate::new(-0.5), Err(MlTrainingError::NonPositive));
        assert_eq!(
            MlLearningRate::new(f64::INFINITY),
            Err(MlTrainingError::NonFinite)
        );
        // Non-finite is checked before sign, so negative infinity is `NonFinite` too.
        assert_eq!(
            MlLearningRate::new(f64::NEG_INFINITY),
            Err(MlTrainingError::NonFinite)
        );
    }

    #[test]
    fn displays_and_parses_training_enums() -> Result<(), MlTrainingError> {
        assert_eq!(
            "timed out".parse::<MlTrainingStatus>()?,
            MlTrainingStatus::TimedOut
        );
        assert_eq!("adamw".parse::<MlOptimizerKind>()?, MlOptimizerKind::AdamW);
        assert_eq!(
            "mean squared error".parse::<MlLossKind>()?,
            MlLossKind::MeanSquaredError
        );
        assert_eq!(MlCheckpointKind::Latest.to_string(), "latest");
        Ok(())
    }

    #[test]
    fn parses_enum_labels_regardless_of_case_and_separator() -> Result<(), MlTrainingError> {
        assert_eq!(
            "Timed_Out".parse::<MlTrainingStatus>()?,
            MlTrainingStatus::TimedOut
        );
        assert_eq!("  ADAMW ".parse::<MlOptimizerKind>()?, MlOptimizerKind::AdamW);
        assert_eq!(
            "Binary-Cross_Entropy".parse::<MlLossKind>()?,
            MlLossKind::BinaryCrossEntropy
        );
        Ok(())
    }

    #[test]
    fn enum_labels_round_trip_through_display() -> Result<(), MlTrainingError> {
        let kinds = [
            MlCheckpointKind::Best,
            MlCheckpointKind::Latest,
            MlCheckpointKind::Epoch,
            MlCheckpointKind::Step,
            MlCheckpointKind::Manual,
            MlCheckpointKind::Final,
        ];
        for kind in kinds.iter().copied() {
            assert_eq!(kind.to_string().parse::<MlCheckpointKind>()?, kind);
        }
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_enum_labels() {
        assert_eq!(
            "  ".parse::<MlTrainingStatus>(),
            Err(MlTrainingError::Empty)
        );
        assert_eq!(
            "warp-speed".parse::<MlOptimizerKind>(),
            Err(MlTrainingError::UnknownLabel)
        );
    }

    #[test]
    fn describes_errors() {
        assert_eq!(
            MlTrainingError::Empty.to_string(),
            "ML training metadata text cannot be empty"
        );
        assert_eq!(
            MlTrainingError::Zero.to_string(),
            "ML training count must be positive"
        );
        assert_eq!(
            MlTrainingError::UnknownLabel.to_string(),
            "unknown ML training metadata label"
        );
    }
}