1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlBatchingKind, MlConfidenceScore, MlInferenceError, MlInferenceMode, MlInferenceRequestId,
10 MlInferenceStatus, MlInputKind, MlLatencyBucket, MlOutputKind, MlPredictionId,
11 MlServingEndpointName, MlServingKind,
12 };
13}
14
// Generates a validated text newtype called `$name`.
//
// The generated type wraps a `String` produced by `non_empty_text`, so the
// stored text is trimmed and blank input is rejected with
// `MlInferenceError::Empty`. It also implements `AsRef<str>`, `Display`,
// `FromStr`, and `TryFrom<&str>`; every conversion ends up in the same
// validation path.
macro_rules! inference_text_newtype {
    ($name:ident) => {
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and wraps the trimmed text.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlInferenceError> {
                let text = non_empty_text(value)?;
                Ok(Self(text))
            }

            /// Borrows the stored text.
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                &self.0
            }
        }

        impl fmt::Display for $name {
            // `write_str` is used on purpose: the text is emitted verbatim and
            // formatter width/padding flags are not applied.
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                out.write_str(&self.0)
            }
        }

        impl FromStr for $name {
            type Err = MlInferenceError;

            fn from_str(text: &str) -> Result<Self, Self::Err> {
                non_empty_text(text).map(Self)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlInferenceError;

            fn try_from(text: &str) -> Result<Self, Self::Error> {
                text.parse()
            }
        }
    };
}
59
// Generates a fieldless label enum called `$name`.
//
// Each `Variant => "label"` pair contributes one variant and its canonical
// label. The generated type exposes `as_str`, prints the label through
// `Display`, and parses through `FromStr` after running the input through
// `normalized_label`.
macro_rules! inference_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
                let label = self.as_str();
                out.write_str(label)
            }
        }

        impl FromStr for $name {
            type Err = MlInferenceError;

            fn from_str(text: &str) -> Result<Self, Self::Err> {
                let wanted = normalized_label(text)?;
                // Scan the variants in declaration order and keep the first
                // one whose canonical label equals the normalised input.
                [$(Self::$variant),+]
                    .iter()
                    .copied()
                    .find(|candidate| candidate.as_str() == wanted)
                    .ok_or(MlInferenceError::UnknownLabel)
            }
        }
    };
}
93
94inference_text_newtype!(MlInferenceRequestId);
95inference_text_newtype!(MlPredictionId);
96inference_text_newtype!(MlServingEndpointName);
97
98#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
99pub struct MlConfidenceScore(f64);
100
101impl MlConfidenceScore {
102 pub fn new(value: f64) -> Result<Self, MlInferenceError> {
103 if !value.is_finite() {
104 return Err(MlInferenceError::NonFinite);
105 }
106 if !(0.0..=1.0).contains(&value) {
107 return Err(MlInferenceError::OutOfRange);
108 }
109 Ok(Self(value))
110 }
111
112 pub const fn value(self) -> f64 {
113 self.0
114 }
115}
116
117inference_enum!(MlInferenceMode {
118 Online => "online",
119 Batch => "batch",
120 Streaming => "streaming",
121 Edge => "edge",
122 Offline => "offline",
123});
124
125inference_enum!(MlInferenceStatus {
126 Pending => "pending",
127 Running => "running",
128 Succeeded => "succeeded",
129 Failed => "failed",
130 Cancelled => "cancelled",
131 TimedOut => "timed-out",
132});
133
134inference_enum!(MlServingKind {
135 Local => "local",
136 Embedded => "embedded",
137 Api => "api",
138 BatchJob => "batch-job",
139 StreamProcessor => "stream-processor",
140 EdgeDevice => "edge-device",
141 Browser => "browser",
142 Mobile => "mobile",
143 Other => "other",
144});
145
146inference_enum!(MlInputKind {
147 Text => "text",
148 Image => "image",
149 Audio => "audio",
150 Video => "video",
151 Tabular => "tabular",
152 Json => "json",
153 Tensor => "tensor",
154 Embedding => "embedding",
155 Multimodal => "multimodal",
156 Other => "other",
157});
158
159inference_enum!(MlOutputKind {
160 Class => "class",
161 Score => "score",
162 Ranking => "ranking",
163 Text => "text",
164 Image => "image",
165 Audio => "audio",
166 BoundingBox => "bounding-box",
167 Mask => "mask",
168 Embedding => "embedding",
169 Tensor => "tensor",
170 Json => "json",
171 Other => "other",
172});
173
174inference_enum!(MlBatchingKind {
175 None => "none",
176 Fixed => "fixed",
177 Dynamic => "dynamic",
178 MicroBatch => "micro-batch",
179 Adaptive => "adaptive",
180});
181
182inference_enum!(MlLatencyBucket {
183 Sub10Ms => "sub-10-ms",
184 Sub50Ms => "sub-50-ms",
185 Sub100Ms => "sub-100-ms",
186 Sub500Ms => "sub-500-ms",
187 Sub1s => "sub-1s",
188 Sub5s => "sub-5s",
189 Over5s => "over-5s",
190 Unknown => "unknown",
191});
192
/// Validation failures reported by the constructors and parsers in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlInferenceError {
    /// Text input was empty once surrounding whitespace was trimmed.
    Empty,
    /// A confidence score was NaN or infinite.
    NonFinite,
    /// A confidence score was finite but outside `0.0..=1.0`.
    OutOfRange,
    /// A label did not match any variant of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for MlInferenceError {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first so there is a single write call.
        let message = match *self {
            Self::Empty => "ML inference metadata text cannot be empty",
            Self::NonFinite => "ML confidence score must be finite",
            Self::OutOfRange => "ML confidence score must be in 0.0..=1.0",
            Self::UnknownLabel => "unknown ML inference metadata label",
        };
        out.write_str(message)
    }
}

impl Error for MlInferenceError {}
213
214fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlInferenceError> {
215 let trimmed = value.as_ref().trim();
216 if trimmed.is_empty() {
217 Err(MlInferenceError::Empty)
218 } else {
219 Ok(trimmed.to_string())
220 }
221}
222
223fn normalized_label(value: &str) -> Result<String, MlInferenceError> {
224 let trimmed = value.trim();
225 if trimmed.is_empty() {
226 Err(MlInferenceError::Empty)
227 } else {
228 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::{
        MlBatchingKind, MlConfidenceScore, MlInferenceError, MlInferenceMode, MlInferenceRequestId,
        MlInferenceStatus, MlInputKind, MlLatencyBucket, MlOutputKind, MlServingKind,
    };

    // Surrounding whitespace is trimmed, and parsing agrees with `new`.
    #[test]
    fn validates_inference_request_ids() {
        let built = MlInferenceRequestId::new(" req-001 ");

        assert_eq!(
            built.as_ref().map(MlInferenceRequestId::as_str),
            Ok("req-001")
        );
        assert_eq!("req-001".parse::<MlInferenceRequestId>(), built);
    }

    // Both bounds are accepted; values outside the range and infinities are
    // rejected with their dedicated errors.
    #[test]
    fn validates_confidence_scores() {
        for bound in [0.0, 1.0] {
            assert_eq!(
                MlConfidenceScore::new(bound).map(MlConfidenceScore::value),
                Ok(bound)
            );
        }
        for outside in [-0.1, 1.1] {
            assert_eq!(
                MlConfidenceScore::new(outside),
                Err(MlInferenceError::OutOfRange)
            );
        }
        assert_eq!(
            MlConfidenceScore::new(f64::INFINITY),
            Err(MlInferenceError::NonFinite)
        );
    }

    // Parsing tolerates spaces and underscores as separators; `Display`
    // prints the canonical hyphenated label.
    #[test]
    fn displays_and_parses_inference_enums() {
        assert_eq!(
            "online".parse::<MlInferenceMode>(),
            Ok(MlInferenceMode::Online)
        );
        assert_eq!(
            "timed out".parse::<MlInferenceStatus>(),
            Ok(MlInferenceStatus::TimedOut)
        );
        assert_eq!(
            "batch job".parse::<MlServingKind>(),
            Ok(MlServingKind::BatchJob)
        );
        assert_eq!("json".parse::<MlInputKind>(), Ok(MlInputKind::Json));
        assert_eq!(
            "bounding box".parse::<MlOutputKind>(),
            Ok(MlOutputKind::BoundingBox)
        );
        assert_eq!(
            "micro_batch".parse::<MlBatchingKind>(),
            Ok(MlBatchingKind::MicroBatch)
        );
        assert_eq!(MlLatencyBucket::Sub100Ms.to_string(), "sub-100-ms");
    }
}