1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlFeatureDriftStatus, MlFeatureEncodingKind, MlFeatureError, MlFeatureId, MlFeatureKind,
10 MlFeatureMissingValuePolicy, MlFeatureName, MlFeatureRole, MlFeatureScalingKind,
11 MlFeatureSource, MlFeatureTransformKind,
12 };
13}
14
/// Defines a validated, `String`-backed newtype for feature metadata text.
///
/// Two public forms are supported. Both trim surrounding whitespace and reject
/// input that is empty after trimming:
///
/// * `feature_text_newtype!(Name, ascii)` validates with `ascii_safe_text`,
///   which additionally restricts the value to ASCII letters, digits, `_`,
///   `-`, `.` and `/`.
/// * `feature_text_newtype!(Name, text)` validates with `non_empty_text`,
///   which accepts any non-empty text.
///
/// Outer attributes written before the type name (for example `///` docs) are
/// forwarded to the generated struct. Both forms expand through the single
/// internal `@define` rule, so the generated API cannot drift between them.
macro_rules! feature_text_newtype {
    ($(#[$meta:meta])* $name:ident, ascii) => {
        feature_text_newtype!(@define $(#[$meta])* $name, ascii_safe_text);
    };
    ($(#[$meta:meta])* $name:ident, text) => {
        feature_text_newtype!(@define $(#[$meta])* $name, non_empty_text);
    };
    (@define $(#[$meta:meta])* $name:ident, $validate:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and stores its trimmed form.
            ///
            /// # Errors
            ///
            /// Returns whatever this type's validator rejects. Blank input is
            /// always reported as [`MlFeatureError::Empty`].
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlFeatureError> {
                $validate(value).map(Self)
            }

            /// Returns the validated, trimmed text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlFeatureError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlFeatureError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
101
/// Defines a closed, `Copy` enum of feature metadata labels.
///
/// Each `Variant => "label"` pair generates one variant. Its label is the
/// canonical string returned by `as_str` and `Display`. Declaration order
/// drives both the derived `Ord`/`PartialOrd` and the order of `ALL`.
///
/// Parsing (`FromStr` / `TryFrom<&str>`) first runs the input through
/// `normalized_label`, then compares the result against the canonical labels.
/// Outer attributes written before the enum name (for example `///` docs) are
/// forwarded to the generated enum.
macro_rules! feature_enum {
    ($(#[$meta:meta])* $name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        $(#[$meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlFeatureError;

            /// Parses a label after normalizing it with `normalized_label`.
            ///
            /// Fails with the normalizer's error for blank input and with
            /// `MlFeatureError::UnknownLabel` when no canonical label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlFeatureError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl generated for the text newtypes so
        // every metadata type offers the same conversion surface.
        impl TryFrom<&str> for $name {
            type Error = MlFeatureError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
135
136feature_text_newtype!(MlFeatureName, ascii);
137feature_text_newtype!(MlFeatureId, text);
138
139feature_enum!(MlFeatureKind {
140 Numeric => "numeric",
141 Categorical => "categorical",
142 Ordinal => "ordinal",
143 Boolean => "boolean",
144 Text => "text",
145 Image => "image",
146 Audio => "audio",
147 Video => "video",
148 Timestamp => "timestamp",
149 Geospatial => "geospatial",
150 Vector => "vector",
151 Embedding => "embedding",
152 Graph => "graph",
153 Other => "other",
154});
155
156feature_enum!(MlFeatureRole {
157 Input => "input",
158 Target => "target",
159 Weight => "weight",
160 Group => "group",
161 Timestamp => "timestamp",
162 Identifier => "identifier",
163 Metadata => "metadata",
164 Ignore => "ignore",
165});
166
167feature_enum!(MlFeatureSource {
168 Raw => "raw",
169 Derived => "derived",
170 Aggregated => "aggregated",
171 Joined => "joined",
172 UserProvided => "user-provided",
173 SystemGenerated => "system-generated",
174 External => "external",
175 Synthetic => "synthetic",
176});
177
178feature_enum!(MlFeatureTransformKind {
179 Normalize => "normalize",
180 Standardize => "standardize",
181 Bucketize => "bucketize",
182 OneHotEncode => "one-hot-encode",
183 Tokenize => "tokenize",
184 Embed => "embed",
185 Impute => "impute",
186 Clip => "clip",
187 Log => "log",
188 Custom => "custom",
189});
190
191feature_enum!(MlFeatureEncodingKind {
192 None => "none",
193 OneHot => "one-hot",
194 Ordinal => "ordinal",
195 Label => "label",
196 Binary => "binary",
197 Hashing => "hashing",
198 Token => "token",
199 Embedding => "embedding",
200 Custom => "custom",
201});
202
203feature_enum!(MlFeatureScalingKind {
204 None => "none",
205 MinMax => "min-max",
206 Standard => "standard",
207 Robust => "robust",
208 UnitNorm => "unit-norm",
209 Log => "log",
210 Custom => "custom",
211});
212
213feature_enum!(MlFeatureMissingValuePolicy {
214 Allow => "allow",
215 Drop => "drop",
216 ImputeMean => "impute-mean",
217 ImputeMedian => "impute-median",
218 ImputeMode => "impute-mode",
219 ImputeConstant => "impute-constant",
220 Unknown => "unknown",
221});
222
223feature_enum!(MlFeatureDriftStatus {
224 Unknown => "unknown",
225 Stable => "stable",
226 Warning => "warning",
227 Drifted => "drifted",
228});
229
/// Error produced when feature metadata text fails validation or parsing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlFeatureError {
    /// The input was empty, or contained only whitespace.
    Empty,
    /// A feature name contained a character other than an ASCII letter,
    /// digit, `_`, `-`, `.` or `/`.
    NonAsciiSafeName,
    /// The input did not match any canonical label of the target enum.
    UnknownLabel,
}

impl fmt::Display for MlFeatureError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            MlFeatureError::Empty => "ML feature metadata text cannot be empty",
            MlFeatureError::NonAsciiSafeName => "ML feature name must be ASCII-safe",
            MlFeatureError::UnknownLabel => "unknown ML feature metadata label",
        };
        formatter.write_str(message)
    }
}

impl Error for MlFeatureError {}
248
249fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlFeatureError> {
250 let trimmed = value.as_ref().trim();
251 if trimmed.is_empty() {
252 Err(MlFeatureError::Empty)
253 } else {
254 Ok(trimmed.to_string())
255 }
256}
257
258fn ascii_safe_text(value: impl AsRef<str>) -> Result<String, MlFeatureError> {
259 let trimmed = non_empty_text(value)?;
260 if trimmed
261 .bytes()
262 .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.' | b'/'))
263 {
264 Ok(trimmed)
265 } else {
266 Err(MlFeatureError::NonAsciiSafeName)
267 }
268}
269
270fn normalized_label(value: &str) -> Result<String, MlFeatureError> {
271 let trimmed = value.trim();
272 if trimmed.is_empty() {
273 Err(MlFeatureError::Empty)
274 } else {
275 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::{
        MlFeatureDriftStatus, MlFeatureEncodingKind, MlFeatureError, MlFeatureId, MlFeatureKind,
        MlFeatureMissingValuePolicy, MlFeatureName, MlFeatureRole, MlFeatureScalingKind,
        MlFeatureSource, MlFeatureTransformKind,
    };

    #[test]
    fn validates_ascii_safe_feature_names() -> Result<(), MlFeatureError> {
        let name = MlFeatureName::new(" sepal_width ")?;

        assert_eq!(name.as_str(), "sepal_width");
        assert_eq!(name.to_string(), "sepal_width");
        assert_eq!("sepal_width".parse::<MlFeatureName>()?, name);
        Ok(())
    }

    #[test]
    fn rejects_invalid_feature_names() {
        assert_eq!(MlFeatureName::new(" "), Err(MlFeatureError::Empty));
        assert_eq!(
            MlFeatureName::new("prompt variable"),
            Err(MlFeatureError::NonAsciiSafeName)
        );
        assert_eq!(
            MlFeatureName::new("城市"),
            Err(MlFeatureError::NonAsciiSafeName)
        );
    }

    // `MlFeatureId` only requires non-empty text, so spaces and non-ASCII
    // characters that `MlFeatureName` rejects are accepted here.
    #[test]
    fn feature_ids_accept_any_non_empty_text() -> Result<(), MlFeatureError> {
        let id = MlFeatureId::new("  城市 / prompt variable ")?;

        assert_eq!(id.as_str(), "城市 / prompt variable");
        assert_eq!(MlFeatureId::try_from("abc")?.to_string(), "abc");
        assert_eq!(MlFeatureId::new(" \t\n"), Err(MlFeatureError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_feature_enums() -> Result<(), MlFeatureError> {
        assert_eq!(
            "one_hot".parse::<MlFeatureEncodingKind>()?,
            MlFeatureEncodingKind::OneHot
        );
        assert_eq!(
            "unit norm".parse::<MlFeatureScalingKind>()?,
            MlFeatureScalingKind::UnitNorm
        );
        assert_eq!(
            "timestamp".parse::<MlFeatureRole>()?,
            MlFeatureRole::Timestamp
        );
        assert_eq!("numeric".parse::<MlFeatureKind>()?, MlFeatureKind::Numeric);
        assert_eq!(
            "drifted".parse::<MlFeatureDriftStatus>()?,
            MlFeatureDriftStatus::Drifted
        );
        assert_eq!(
            "impute mean".parse::<MlFeatureMissingValuePolicy>()?,
            MlFeatureMissingValuePolicy::ImputeMean
        );
        assert_eq!(
            MlFeatureTransformKind::OneHotEncode.to_string(),
            "one-hot-encode"
        );
        Ok(())
    }

    #[test]
    fn enum_labels_round_trip_through_display() -> Result<(), MlFeatureError> {
        for source in [
            MlFeatureSource::Raw,
            MlFeatureSource::UserProvided,
            MlFeatureSource::SystemGenerated,
        ] {
            assert_eq!(source.to_string().parse::<MlFeatureSource>()?, source);
        }
        assert_eq!(
            "User Provided".parse::<MlFeatureSource>()?,
            MlFeatureSource::UserProvided
        );
        assert_eq!(
            "ONE_HOT_ENCODE".parse::<MlFeatureTransformKind>()?,
            MlFeatureTransformKind::OneHotEncode
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_enum_labels() {
        assert_eq!("".parse::<MlFeatureKind>(), Err(MlFeatureError::Empty));
        assert_eq!("   ".parse::<MlFeatureRole>(), Err(MlFeatureError::Empty));
        assert_eq!(
            "quantum".parse::<MlFeatureKind>(),
            Err(MlFeatureError::UnknownLabel)
        );
        // A label belonging to a different enum must not be accepted.
        assert_eq!(
            "one hot encode".parse::<MlFeatureSource>(),
            Err(MlFeatureError::UnknownLabel)
        );
    }

    #[test]
    fn errors_have_readable_messages() {
        assert_eq!(
            MlFeatureError::Empty.to_string(),
            "ML feature metadata text cannot be empty"
        );
        assert_eq!(
            MlFeatureError::NonAsciiSafeName.to_string(),
            "ML feature name must be ASCII-safe"
        );
        assert_eq!(
            MlFeatureError::UnknownLabel.to_string(),
            "unknown ML feature metadata label"
        );
    }
}