Skip to main content

use_ml_dataset/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        MlDatasetCardRef, MlDatasetError, MlDatasetId, MlDatasetKind, MlDatasetLicense,
10        MlDatasetName, MlDatasetProvenance, MlDatasetSchemaRef, MlDatasetSplit, MlDatasetVersion,
11        MlExampleId, MlExampleKind,
12    };
13}
14
/// Defines a validated text newtype for dataset metadata.
///
/// The generated type wraps a `String` whose surrounding whitespace has been
/// trimmed and which is guaranteed to be non-empty (validation is delegated to
/// `non_empty_text`). It implements `AsRef<str>`, `Display`, `FromStr`,
/// `TryFrom<&str>`, `TryFrom<String>`, and `From<$name> for String`, so values
/// can be converted from and back into plain strings with the standard
/// conversion traits.
macro_rules! dataset_text_newtype {
    ($name:ident) => {
        #[doc = concat!(
            "Non-empty, whitespace-trimmed dataset metadata text (`",
            stringify!($name),
            "`).\n\nBuild it with [`",
            stringify!($name),
            "::new`], `str::parse`, or `TryFrom`; blank input is rejected with ",
            "[`MlDatasetError::Empty`]."
        )]
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Trims surrounding whitespace from `value` and wraps the result.
            ///
            /// # Errors
            ///
            /// Returns [`MlDatasetError::Empty`] when `value` is empty or
            /// contains only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlDatasetError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated (trimmed, non-empty) text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the wrapper and returns the owned text.
            #[must_use]
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlDatasetError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlDatasetError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Owned input gets the same validation as borrowed input.
        impl TryFrom<String> for $name {
            type Error = MlDatasetError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers hand the validated text back to `String`-based APIs
        // (and gives `Into<String>` via the blanket impl).
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_string()
            }
        }
    };
}
63
/// Defines a closed set of dataset metadata labels as a `Copy` enum.
///
/// Each `Variant => "label"` pair becomes a variant whose canonical text is
/// `"label"`. The generated type provides:
/// - `as_str`, which returns the canonical label;
/// - an `ALL` slice of every variant in declaration order;
/// - `Display`, which writes the canonical label;
/// - `FromStr` and `TryFrom<&str>`.
///
/// Parsing compares the canonical labels against the output of
/// `normalized_label`. That output is trimmed, ASCII-lowercased, and uses `-`
/// in place of `_` and spaces. Labels should therefore be written lowercase
/// with `-` separators; otherwise that variant can never be parsed.
/// Declaration order also defines the derived `Ord`/`PartialOrd` ordering.
macro_rules! dataset_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[doc = concat!("Closed set of dataset metadata labels (`", stringify!($name), "`).")]
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("The `", $label, "` label.")]
                $variant
            ),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlDatasetError;

            /// Parses a label after normalizing it with `normalized_label`.
            ///
            /// # Errors
            ///
            /// Returns [`MlDatasetError::Empty`] for blank input and
            /// [`MlDatasetError::UnknownLabel`] when no canonical label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlDatasetError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl the text newtypes provide.
        impl TryFrom<&str> for $name {
            type Error = MlDatasetError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::from_str(value)
            }
        }
    };
}
97
98dataset_text_newtype!(MlDatasetName);
99dataset_text_newtype!(MlDatasetId);
100dataset_text_newtype!(MlDatasetVersion);
101dataset_text_newtype!(MlExampleId);
102dataset_text_newtype!(MlDatasetLicense);
103dataset_text_newtype!(MlDatasetSchemaRef);
104dataset_text_newtype!(MlDatasetCardRef);
105
/// A named partition of a dataset.
///
/// The common splits have dedicated variants, and any other name is kept
/// verbatim in [`MlDatasetSplit::Custom`]. Variant order defines the derived
/// ordering, so every named split sorts before any `Custom` value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MlDatasetSplit {
    /// Displayed as `train`.
    Train,
    /// Displayed as `validation`.
    Validation,
    /// Displayed as `test`.
    Test,
    /// Displayed as `holdout`.
    Holdout,
    /// Displayed as `calibration`.
    Calibration,
    /// Displayed as `shadow`.
    Shadow,
    /// Displayed as `production`.
    Production,
    /// Any other split name, displayed exactly as stored.
    Custom(String),
}

impl fmt::Display for MlDatasetSplit {
    /// Writes the lowercase canonical name, or the stored text for `Custom`.
    ///
    /// Output goes through `write_str`, so width/fill/alignment flags are
    /// not applied.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label: &str = match self {
            Self::Train => "train",
            Self::Validation => "validation",
            Self::Test => "test",
            Self::Holdout => "holdout",
            Self::Calibration => "calibration",
            Self::Shadow => "shadow",
            Self::Production => "production",
            Self::Custom(text) => text,
        };
        formatter.write_str(label)
    }
}
132
133impl FromStr for MlDatasetSplit {
134    type Err = MlDatasetError;
135
136    fn from_str(value: &str) -> Result<Self, Self::Err> {
137        let trimmed = value.trim();
138        if trimmed.is_empty() {
139            return Err(MlDatasetError::Empty);
140        }
141
142        Ok(match normalized_label(trimmed)?.as_str() {
143            "train" | "training" => Self::Train,
144            "validation" | "valid" | "val" => Self::Validation,
145            "test" => Self::Test,
146            "holdout" => Self::Holdout,
147            "calibration" => Self::Calibration,
148            "shadow" => Self::Shadow,
149            "production" | "prod" => Self::Production,
150            _ => Self::Custom(trimmed.to_string()),
151        })
152    }
153}
154
155dataset_enum!(MlDatasetKind {
156    Tabular => "tabular",
157    Text => "text",
158    Image => "image",
159    Audio => "audio",
160    Video => "video",
161    TimeSeries => "time-series",
162    Graph => "graph",
163    Multimodal => "multimodal",
164    Synthetic => "synthetic",
165    Other => "other",
166});
167
168dataset_enum!(MlExampleKind {
169    Labeled => "labeled",
170    Unlabeled => "unlabeled",
171    WeaklyLabeled => "weakly-labeled",
172    PseudoLabeled => "pseudo-labeled",
173    Augmented => "augmented",
174});
175
176dataset_enum!(MlDatasetProvenance {
177    HumanCreated => "human-created",
178    MachineGenerated => "machine-generated",
179    Synthetic => "synthetic",
180    Scraped => "scraped",
181    Instrumented => "instrumented",
182    Mixed => "mixed",
183    Unknown => "unknown",
184});
185
/// Error produced when validating or parsing dataset metadata text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MlDatasetError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The input did not match any canonical label of the target enum.
    UnknownLabel,
}

impl fmt::Display for MlDatasetError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "ML dataset metadata text cannot be empty",
            Self::UnknownLabel => "unknown ML dataset metadata label",
        };
        formatter.write_str(message)
    }
}

// No wrapped cause, so the default `source()` (returning `None`) is correct.
impl Error for MlDatasetError {}
202
203fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlDatasetError> {
204    let trimmed = value.as_ref().trim();
205    if trimmed.is_empty() {
206        Err(MlDatasetError::Empty)
207    } else {
208        Ok(trimmed.to_string())
209    }
210}
211
212fn normalized_label(value: &str) -> Result<String, MlDatasetError> {
213    let trimmed = value.trim();
214    if trimmed.is_empty() {
215        Err(MlDatasetError::Empty)
216    } else {
217        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::{
        MlDatasetError, MlDatasetKind, MlDatasetName, MlDatasetProvenance, MlDatasetSplit,
        MlDatasetVersion, MlExampleId, MlExampleKind,
    };

    #[test]
    fn validates_dataset_names() -> Result<(), MlDatasetError> {
        let name = MlDatasetName::new(" iris ")?;

        assert_eq!(name.as_str(), "iris");
        assert_eq!(name.to_string(), "iris");
        assert_eq!("iris".parse::<MlDatasetName>()?, name);
        Ok(())
    }

    #[test]
    fn rejects_empty_dataset_names() {
        assert_eq!(MlDatasetName::new("  "), Err(MlDatasetError::Empty));
    }

    #[test]
    fn converts_text_newtypes() -> Result<(), MlDatasetError> {
        let version = MlDatasetVersion::try_from(" v1.2.0 ")?;
        let text: &str = version.as_ref();
        assert_eq!(text, "v1.2.0");
        assert_eq!(version.into_string(), "v1.2.0");
        assert_eq!(MlExampleId::try_from("\t"), Err(MlDatasetError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_dataset_enums() -> Result<(), MlDatasetError> {
        assert_eq!(MlDatasetSplit::Validation.to_string(), "validation");
        assert_eq!(
            "time_series".parse::<MlDatasetKind>()?,
            MlDatasetKind::TimeSeries
        );
        assert_eq!(
            "pseudo labeled".parse::<MlExampleKind>()?,
            MlExampleKind::PseudoLabeled
        );
        assert_eq!(
            "machine-generated".parse::<MlDatasetProvenance>()?,
            MlDatasetProvenance::MachineGenerated
        );
        assert_eq!(
            "shadow-2026".parse::<MlDatasetSplit>()?,
            MlDatasetSplit::Custom("shadow-2026".to_string())
        );
        Ok(())
    }

    #[test]
    fn parses_split_aliases_and_keeps_custom_text() -> Result<(), MlDatasetError> {
        assert_eq!("Training".parse::<MlDatasetSplit>()?, MlDatasetSplit::Train);
        assert_eq!("VAL".parse::<MlDatasetSplit>()?, MlDatasetSplit::Validation);
        assert_eq!("valid".parse::<MlDatasetSplit>()?, MlDatasetSplit::Validation);
        assert_eq!(" prod ".parse::<MlDatasetSplit>()?, MlDatasetSplit::Production);
        // Custom splits keep the caller's casing and separators (only trimmed).
        assert_eq!(
            " Shadow_2026 ".parse::<MlDatasetSplit>()?,
            MlDatasetSplit::Custom("Shadow_2026".to_string())
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!("".parse::<MlDatasetSplit>(), Err(MlDatasetError::Empty));
        assert_eq!("  ".parse::<MlDatasetKind>(), Err(MlDatasetError::Empty));
        assert_eq!(
            "spreadsheet".parse::<MlDatasetKind>(),
            Err(MlDatasetError::UnknownLabel)
        );
        assert_eq!(
            "labelled".parse::<MlExampleKind>(),
            Err(MlDatasetError::UnknownLabel)
        );
    }

    #[test]
    fn enum_labels_round_trip_through_display() -> Result<(), MlDatasetError> {
        let kinds = [
            MlDatasetKind::Tabular,
            MlDatasetKind::Text,
            MlDatasetKind::Image,
            MlDatasetKind::Audio,
            MlDatasetKind::Video,
            MlDatasetKind::TimeSeries,
            MlDatasetKind::Graph,
            MlDatasetKind::Multimodal,
            MlDatasetKind::Synthetic,
            MlDatasetKind::Other,
        ];
        for kind in kinds {
            assert_eq!(kind.to_string().parse::<MlDatasetKind>()?, kind);
        }

        let example_kinds = [
            MlExampleKind::Labeled,
            MlExampleKind::Unlabeled,
            MlExampleKind::WeaklyLabeled,
            MlExampleKind::PseudoLabeled,
            MlExampleKind::Augmented,
        ];
        for kind in example_kinds {
            assert_eq!(kind.to_string().parse::<MlExampleKind>()?, kind);
        }

        let provenances = [
            MlDatasetProvenance::HumanCreated,
            MlDatasetProvenance::MachineGenerated,
            MlDatasetProvenance::Synthetic,
            MlDatasetProvenance::Scraped,
            MlDatasetProvenance::Instrumented,
            MlDatasetProvenance::Mixed,
            MlDatasetProvenance::Unknown,
        ];
        for provenance in provenances {
            assert_eq!(
                provenance.to_string().parse::<MlDatasetProvenance>()?,
                provenance
            );
        }

        let splits = [
            MlDatasetSplit::Train,
            MlDatasetSplit::Validation,
            MlDatasetSplit::Test,
            MlDatasetSplit::Holdout,
            MlDatasetSplit::Calibration,
            MlDatasetSplit::Shadow,
            MlDatasetSplit::Production,
        ];
        for split in splits {
            assert_eq!(split.to_string().parse::<MlDatasetSplit>()?, split);
        }
        Ok(())
    }
}