Skip to main content

use_ml_feature/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        MlFeatureDriftStatus, MlFeatureEncodingKind, MlFeatureError, MlFeatureId, MlFeatureKind,
10        MlFeatureMissingValuePolicy, MlFeatureName, MlFeatureRole, MlFeatureScalingKind,
11        MlFeatureSource, MlFeatureTransformKind,
12    };
13}
14
// Generates a validated, `String`-backed text newtype named `$name`.
//
// Invocation forms:
//   feature_text_newtype!(Name, ascii) -- input is validated by `ascii_safe_text`
//   feature_text_newtype!(Name, text)  -- input is validated by `non_empty_text`
//
// Both public forms forward to the internal `@define` rule, so the generated
// API (`new`, `as_str`, `AsRef<str>`, `Display`, `FromStr`, `TryFrom<&str>`)
// is written exactly once and cannot drift between the two flavours.
macro_rules! feature_text_newtype {
    ($name:ident, ascii) => {
        feature_text_newtype!(@define $name, ascii_safe_text);
    };
    ($name:ident, text) => {
        feature_text_newtype!(@define $name, non_empty_text);
    };
    // Internal rule. `$validate` names a function in scope at the invocation
    // site that maps `impl AsRef<str>` to `Result<String, MlFeatureError>`.
    (@define $name:ident, $validate:ident) => {
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and wraps the text returned by the validator.
            ///
            /// # Errors
            ///
            /// Returns the `MlFeatureError` produced by the validator when
            /// `value` is rejected.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlFeatureError> {
                $validate(value).map(Self)
            }

            /// Borrows the validated text.
            pub fn as_str(&self) -> &str {
                &self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        // Both string conversions go through `new`, so every construction
        // path applies the same validation.
        impl FromStr for $name {
            type Err = MlFeatureError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlFeatureError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }
    };
}
101
// Generates a fieldless, `Copy` label enum named `$name`.
//
// Each `Variant => "label"` pair supplies the canonical text returned by
// `as_str` / `Display` and matched by `FromStr` / `TryFrom<&str>`. Parsing
// first runs the input through `normalized_label`, so matching is insensitive
// to ASCII case and treats `_` and spaces like `-`.
//
// Optional attributes (including `///` doc comments) may precede the enum name
// and any variant; they are forwarded onto the generated items. Invocations
// without attributes keep working unchanged.
//
// Variant declaration order defines the derived `Ord` / `PartialOrd` order.
macro_rules! feature_enum {
    (
        $(#[$enum_meta:meta])*
        $name:ident {
            $( $(#[$variant_meta:meta])* $variant:ident => $label:literal ),+ $(,)?
        }
    ) => {
        $(#[$enum_meta])*
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $( $(#[$variant_meta])* $variant ),+
        }

        impl $name {
            /// Returns the canonical label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlFeatureError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // `normalized_label` reports blank input as `Empty`; anything
                // that normalizes to a non-label falls through to `UnknownLabel`.
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlFeatureError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` conversion offered by the text newtypes
        // so both families of types share the same conversion surface.
        impl TryFrom<&str> for $name {
            type Error = MlFeatureError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::from_str(value)
            }
        }
    };
}
135
// Concrete text newtypes generated by `feature_text_newtype!`.
// `MlFeatureName` uses the `ascii` arm, so its input is checked by
// `ascii_safe_text`; `MlFeatureId` uses the `text` arm, which only requires
// non-empty text after trimming (`non_empty_text`).
136feature_text_newtype!(MlFeatureName, ascii);
137feature_text_newtype!(MlFeatureId, text);
138
// Declares `MlFeatureKind` through `feature_enum!`. Each `Variant => "label"`
// pair is the text returned by `as_str`/`Display` and accepted by `FromStr`.
// The variant order below is the order used by the derived `Ord`/`PartialOrd`.
139feature_enum!(MlFeatureKind {
140    Numeric => "numeric",
141    Categorical => "categorical",
142    Ordinal => "ordinal",
143    Boolean => "boolean",
144    Text => "text",
145    Image => "image",
146    Audio => "audio",
147    Video => "video",
148    Timestamp => "timestamp",
149    Geospatial => "geospatial",
150    Vector => "vector",
151    Embedding => "embedding",
152    Graph => "graph",
153    Other => "other",
154});
155
// Declares `MlFeatureRole` through `feature_enum!` (label pairs drive
// `as_str`/`Display`/`FromStr`; variant order drives the derived ordering).
// The "timestamp" label also appears in `MlFeatureKind`; the two enums are
// separate types, so the target type of the parse decides which one is produced.
156feature_enum!(MlFeatureRole {
157    Input => "input",
158    Target => "target",
159    Weight => "weight",
160    Group => "group",
161    Timestamp => "timestamp",
162    Identifier => "identifier",
163    Metadata => "metadata",
164    Ignore => "ignore",
165});
166
// Declares `MlFeatureSource` through `feature_enum!`.
// Hyphenated labels such as "user-provided" also parse from `user_provided`
// or `user provided`, because `normalized_label` rewrites `_` and spaces to
// `-` (and lowercases ASCII) before the label match.
167feature_enum!(MlFeatureSource {
168    Raw => "raw",
169    Derived => "derived",
170    Aggregated => "aggregated",
171    Joined => "joined",
172    UserProvided => "user-provided",
173    SystemGenerated => "system-generated",
174    External => "external",
175    Synthetic => "synthetic",
176});
177
// Declares `MlFeatureTransformKind` through `feature_enum!`. Labels are the
// canonical kebab-case strings emitted by `Display` (e.g. "one-hot-encode")
// and matched by `FromStr` after normalization; variant order is the derived
// `Ord` order.
178feature_enum!(MlFeatureTransformKind {
179    Normalize => "normalize",
180    Standardize => "standardize",
181    Bucketize => "bucketize",
182    OneHotEncode => "one-hot-encode",
183    Tokenize => "tokenize",
184    Embed => "embed",
185    Impute => "impute",
186    Clip => "clip",
187    Log => "log",
188    Custom => "custom",
189});
190
// Declares `MlFeatureEncodingKind` through `feature_enum!`.
// Note the variant literally named `None`: it is `MlFeatureEncodingKind::None`
// with label "none", unrelated to `Option::None`. It is declared first, so it
// sorts lowest under the derived `Ord`.
191feature_enum!(MlFeatureEncodingKind {
192    None => "none",
193    OneHot => "one-hot",
194    Ordinal => "ordinal",
195    Label => "label",
196    Binary => "binary",
197    Hashing => "hashing",
198    Token => "token",
199    Embedding => "embedding",
200    Custom => "custom",
201});
202
// Declares `MlFeatureScalingKind` through `feature_enum!`.
// As with the encoding enum, `None` here is this enum's own variant (label
// "none"), not `Option::None`. Labels like "unit-norm" also parse from
// `unit_norm` / `unit norm` thanks to `normalized_label`.
203feature_enum!(MlFeatureScalingKind {
204    None => "none",
205    MinMax => "min-max",
206    Standard => "standard",
207    Robust => "robust",
208    UnitNorm => "unit-norm",
209    Log => "log",
210    Custom => "custom",
211});
212
// Declares `MlFeatureMissingValuePolicy` through `feature_enum!`.
// `Unknown` is declared last, so under the derived `Ord` it compares greater
// than every other variant of this enum.
213feature_enum!(MlFeatureMissingValuePolicy {
214    Allow => "allow",
215    Drop => "drop",
216    ImputeMean => "impute-mean",
217    ImputeMedian => "impute-median",
218    ImputeMode => "impute-mode",
219    ImputeConstant => "impute-constant",
220    Unknown => "unknown",
221});
222
// Declares `MlFeatureDriftStatus` through `feature_enum!`.
// `Unknown` is declared first here, so it is the smallest value under the
// derived `Ord`, followed by `Stable`, `Warning`, and `Drifted`.
223feature_enum!(MlFeatureDriftStatus {
224    Unknown => "unknown",
225    Stable => "stable",
226    Warning => "warning",
227    Drifted => "drifted",
228});
229
/// Reasons a piece of ML feature metadata text can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlFeatureError {
    /// The input was empty once surrounding whitespace was trimmed.
    Empty,
    /// The input contained a byte outside the ASCII-safe set accepted for names.
    NonAsciiSafeName,
    /// The input did not match any label of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for MlFeatureError {
    /// Writes the fixed, human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::Empty => "ML feature metadata text cannot be empty",
            Self::NonAsciiSafeName => "ML feature name must be ASCII-safe",
            Self::UnknownLabel => "unknown ML feature metadata label",
        };
        f.write_str(message)
    }
}

impl Error for MlFeatureError {}
248
249fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlFeatureError> {
250    let trimmed = value.as_ref().trim();
251    if trimmed.is_empty() {
252        Err(MlFeatureError::Empty)
253    } else {
254        Ok(trimmed.to_string())
255    }
256}
257
258fn ascii_safe_text(value: impl AsRef<str>) -> Result<String, MlFeatureError> {
259    let trimmed = non_empty_text(value)?;
260    if trimmed
261        .bytes()
262        .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'_' | b'-' | b'.' | b'/'))
263    {
264        Ok(trimmed)
265    } else {
266        Err(MlFeatureError::NonAsciiSafeName)
267    }
268}
269
270fn normalized_label(value: &str) -> Result<String, MlFeatureError> {
271    let trimmed = value.trim();
272    if trimmed.is_empty() {
273        Err(MlFeatureError::Empty)
274    } else {
275        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::{
        MlFeatureDriftStatus, MlFeatureEncodingKind, MlFeatureError, MlFeatureKind,
        MlFeatureMissingValuePolicy, MlFeatureName, MlFeatureRole, MlFeatureScalingKind,
        MlFeatureTransformKind,
    };

    /// Surrounding whitespace is trimmed, and `as_str`, `Display`, and
    /// `FromStr` all agree on the resulting name.
    #[test]
    fn validates_ascii_safe_feature_names() -> Result<(), MlFeatureError> {
        let expected = "sepal_width";
        let name = MlFeatureName::new(" sepal_width ")?;
        let parsed: MlFeatureName = expected.parse()?;

        assert_eq!(name.as_str(), expected);
        assert_eq!(name.to_string(), expected);
        assert_eq!(parsed, name);
        Ok(())
    }

    /// Blank input, embedded spaces, and non-ASCII text are each rejected
    /// with the matching error variant.
    #[test]
    fn rejects_invalid_feature_names() {
        let cases = [
            ("  ", MlFeatureError::Empty),
            ("prompt variable", MlFeatureError::NonAsciiSafeName),
            ("城市", MlFeatureError::NonAsciiSafeName),
        ];

        for (input, expected) in cases {
            assert_eq!(MlFeatureName::new(input), Err(expected));
        }
    }

    /// Enum parsing tolerates `_` and space separators; `Display` emits the
    /// canonical hyphenated label.
    #[test]
    fn displays_and_parses_feature_enums() -> Result<(), MlFeatureError> {
        let encoding: MlFeatureEncodingKind = "one_hot".parse()?;
        assert_eq!(encoding, MlFeatureEncodingKind::OneHot);

        let scaling: MlFeatureScalingKind = "unit norm".parse()?;
        assert_eq!(scaling, MlFeatureScalingKind::UnitNorm);

        let role: MlFeatureRole = "timestamp".parse()?;
        assert_eq!(role, MlFeatureRole::Timestamp);

        let kind: MlFeatureKind = "numeric".parse()?;
        assert_eq!(kind, MlFeatureKind::Numeric);

        let drift: MlFeatureDriftStatus = "drifted".parse()?;
        assert_eq!(drift, MlFeatureDriftStatus::Drifted);

        let policy: MlFeatureMissingValuePolicy = "impute mean".parse()?;
        assert_eq!(policy, MlFeatureMissingValuePolicy::ImputeMean);

        let rendered = MlFeatureTransformKind::OneHotEncode.to_string();
        assert_eq!(rendered, "one-hot-encode");
        Ok(())
    }
}
339}