// scirs2_datasets/huggingface.rs

1//! HuggingFace dataset card metadata parsing and generation.
2//!
3//! This module provides support for HuggingFace dataset cards — the YAML
4//! frontmatter found in `README.md` files of HuggingFace Hub datasets.
5//!
6//! No external YAML crate is required; a minimal subset parser handles the
7//! specific fields used by HuggingFace dataset cards.
8//!
9//! ## Example
10//!
11//! ```rust
12//! use scirs2_datasets::huggingface::{parse_dataset_card, to_hf_card, card_to_readme};
13//!
14//! let yaml = "dataset_name: my-dataset\ntask_categories:\n  - text-classification\n";
15//! let card = parse_dataset_card(yaml).expect("parse ok");
16//! assert_eq!(card.dataset_name, "my-dataset");
17//!
18//! let card2 = to_hf_card("test-ds", 1000, "classification");
19//! let readme = card_to_readme(&card2);
20//! assert!(readme.contains("test-ds"));
21//! ```
22
23use std::io;
24use std::path::Path;
25
26// ─────────────────────────────────────────────────────────────────────────────
27// Public error type
28// ─────────────────────────────────────────────────────────────────────────────
29
/// Errors that can occur when working with HuggingFace dataset cards.
///
/// `Io` preserves the underlying error (exposed via `Error::source`); the
/// other variants carry only a human-readable description.
#[derive(Debug)]
pub enum HfError {
    /// I/O error while reading a file.
    Io(io::Error),
    /// Parsing error with a descriptive message.
    Parse(String),
    /// Required field is missing from the dataset card.
    MissingField(&'static str),
}
40
41impl std::fmt::Display for HfError {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        match self {
44            HfError::Io(e) => write!(f, "IO error: {e}"),
45            HfError::Parse(msg) => write!(f, "parse error: {msg}"),
46            HfError::MissingField(field) => write!(f, "missing field: {field}"),
47        }
48    }
49}
50
51impl std::error::Error for HfError {
52    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
53        match self {
54            HfError::Io(e) => Some(e),
55            _ => None,
56        }
57    }
58}
59
60impl From<io::Error> for HfError {
61    fn from(e: io::Error) -> Self {
62        HfError::Io(e)
63    }
64}
65
66// ─────────────────────────────────────────────────────────────────────────────
67// Public types
68// ─────────────────────────────────────────────────────────────────────────────
69
/// Information about a single dataset split (train / validation / test …).
///
/// Mirrors one entry of the `splits:` list in a dataset card's YAML
/// frontmatter (see `parse_splits_from_yaml` for the expected layout).
#[derive(Debug, Clone, PartialEq)]
pub struct HfSplitInfo {
    /// Split name — typically `"train"`, `"test"`, or `"validation"`.
    pub name: String,
    /// Number of rows / examples in this split.
    pub num_rows: usize,
    /// Approximate size of this split in bytes.
    pub num_bytes: usize,
}
80
/// HuggingFace dataset card metadata parsed from the YAML frontmatter in
/// a `README.md` file.
///
/// `Default` yields an empty card (empty strings/vectors, `None` options);
/// `parse_dataset_card` starts from that and fills in whichever keys are
/// present in the YAML.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct HfDatasetCard {
    /// Dataset identifier / slug, e.g. `"squad"`.
    pub dataset_name: String,
    /// HuggingFace task category strings, e.g. `["text-classification"]`.
    pub task_categories: Vec<String>,
    /// BCP-47 language codes, e.g. `["en", "fr"]`.
    pub language: Vec<String>,
    /// HuggingFace size category tags, e.g. `["1M<n<10M"]`.
    pub size_categories: Vec<String>,
    /// SPDX license identifier, e.g. `"apache-2.0"`.
    pub license: Option<String>,
    /// Human-readable dataset name that may differ from the slug.
    pub pretty_name: Option<String>,
    /// Per-split statistics (train, test, validation, …).
    pub splits: Vec<HfSplitInfo>,
}
100
101// ─────────────────────────────────────────────────────────────────────────────
102// Minimal YAML parser
103// ─────────────────────────────────────────────────────────────────────────────
104
/// Parse a YAML value that appears as the rest of a `key: <rest>` line.
///
/// Strips surrounding single or double quotes and resolves the basic quote
/// escapes used by YAML flow scalars: `\"` inside double quotes and the
/// doubled `''` inside single quotes. This keeps parsing symmetric with the
/// escaping performed by `yaml_str` when cards are written back out.
fn parse_scalar(s: &str) -> String {
    let s = s.trim();
    // The `len() >= 2` guard keeps a lone quote character intact.
    if s.len() >= 2 && s.starts_with('"') && s.ends_with('"') {
        // Double-quoted scalar: strip quotes and unescape `\"`.
        s[1..s.len() - 1].replace("\\\"", "\"")
    } else if s.len() >= 2 && s.starts_with('\'') && s.ends_with('\'') {
        // Single-quoted scalar: strip quotes and unescape doubled `''`.
        s[1..s.len() - 1].replace("''", "'")
    } else {
        s.to_owned()
    }
}
117
/// Return the number of leading ASCII space characters in `line`.
/// Tabs and other whitespace are deliberately not counted.
fn indent_of(line: &str) -> usize {
    line.bytes().take_while(|&b| b == b' ').count()
}
122
/// Find the byte index of the first `:` that is not inside a quoted string.
fn find_colon(s: &str) -> Option<usize> {
    let mut in_single = false;
    let mut in_double = false;
    for (idx, ch) in s.char_indices() {
        if ch == '\'' && !in_double {
            in_single = !in_single;
        } else if ch == '"' && !in_single {
            in_double = !in_double;
        } else if ch == ':' && !in_single && !in_double {
            return Some(idx);
        }
    }
    None
}
137
138/// Parse just the fields used in HuggingFace dataset cards from raw YAML text.
139///
140/// Handles:
141/// - `key: scalar` top-level entries
142/// - `key:\n  - item\n  - item` block lists (depth 1)
143/// - `key: [a, b, c]` inline lists
144///
145/// Returns a list of `(key, values)` pairs — values may be a list even for
146/// scalar entries (list of one element).
147fn parse_hf_yaml(yaml: &str) -> Vec<(String, Vec<String>)> {
148    let mut result: Vec<(String, Vec<String>)> = Vec::new();
149    let lines: Vec<&str> = yaml.lines().collect();
150    let mut i = 0;
151
152    while i < lines.len() {
153        let line = lines[i];
154        let trimmed = line.trim();
155
156        // Skip blank lines, comments, and YAML document markers.
157        if trimmed.is_empty() || trimmed.starts_with('#') || trimmed == "---" {
158            i += 1;
159            continue;
160        }
161
162        // Only process top-level keys (indent == 0).
163        if indent_of(line) != 0 {
164            i += 1;
165            continue;
166        }
167
168        if let Some(colon) = find_colon(line) {
169            let key = line[..colon].trim().to_owned();
170            let rest = line[colon + 1..].trim();
171
172            if rest.is_empty() {
173                // Value spans subsequent lines.
174                i += 1;
175                let mut items: Vec<String> = Vec::new();
176                while i < lines.len() {
177                    let sub = lines[i];
178                    let sub_trimmed = sub.trim();
179                    // Back to top-level — stop
180                    if !sub_trimmed.is_empty()
181                        && !sub_trimmed.starts_with('#')
182                        && indent_of(sub) == 0
183                    {
184                        break;
185                    }
186                    if let Some(rest) = sub_trimmed.strip_prefix("- ") {
187                        items.push(parse_scalar(rest));
188                    } else if sub_trimmed == "-" {
189                        items.push(String::new());
190                    }
191                    // Skip sub-key maps (splits, features); only collect list items.
192                    i += 1;
193                }
194                result.push((key, items));
195                continue;
196            } else if rest.starts_with('[') && rest.ends_with(']') {
197                // Inline list.
198                let inner = &rest[1..rest.len() - 1];
199                let items: Vec<String> = inner.split(',').map(parse_scalar).collect();
200                result.push((key, items));
201            } else {
202                result.push((key, vec![parse_scalar(rest)]));
203            }
204        }
205        i += 1;
206    }
207
208    result
209}
210
/// Parse nested split blocks from YAML text.
///
/// Looks for:
/// ```text
/// splits:
///   - name: train
///     num_rows: 1000
///     num_bytes: 8192
/// ```
///
/// Unknown sub-keys are ignored; `num_rows` / `num_bytes` default to 0 when
/// absent or unparseable. Returns an empty vector when no top-level
/// `splits:` key exists.
fn parse_splits_from_yaml(yaml: &str) -> Vec<HfSplitInfo> {
    let mut splits: Vec<HfSplitInfo> = Vec::new();
    let lines: Vec<&str> = yaml.lines().collect();
    let mut i = 0;

    // Find "splits:" at indent 0
    while i < lines.len() {
        let line = lines[i];
        let trimmed = line.trim();
        if indent_of(line) == 0 && trimmed.starts_with("splits:") {
            i += 1;
            // Collect the block
            while i < lines.len() {
                let sub = lines[i];
                let sub_trimmed = sub.trim();
                // A non-blank, non-comment line back at indent 0 ends the block.
                if !sub_trimmed.is_empty() && !sub_trimmed.starts_with('#') && indent_of(sub) == 0 {
                    break;
                }
                // New list item starting with "- name:" or just "-"
                if sub_trimmed.starts_with("- name:") || sub_trimmed == "-" {
                    let name_part = if let Some(rest) = sub_trimmed.strip_prefix("- name:") {
                        parse_scalar(rest)
                    } else {
                        // Bare "-" item: the name stays empty — a later
                        // `name:` sub-key is NOT picked up by the inner loop.
                        String::new()
                    };
                    let mut num_rows = 0usize;
                    let mut num_bytes = 0usize;
                    // Read sub-keys until next "- " at same indent or lower
                    let item_indent = indent_of(sub);
                    i += 1;
                    while i < lines.len() {
                        let inner = lines[i];
                        let inner_trimmed = inner.trim();
                        if inner_trimmed.is_empty() || inner_trimmed.starts_with('#') {
                            i += 1;
                            continue;
                        }
                        let inner_indent = indent_of(inner);
                        // Back to parent block or next sibling
                        if inner_indent <= item_indent
                            && (inner_trimmed.starts_with('-') || inner_indent == 0)
                        {
                            // Leave `i` pointing at this line so the outer
                            // loop re-processes it (it may be the next item).
                            break;
                        }
                        if let Some(colon) = find_colon(inner_trimmed) {
                            let k = inner_trimmed[..colon].trim();
                            let v = parse_scalar(&inner_trimmed[colon + 1..]);
                            match k {
                                "num_rows" => {
                                    num_rows = v.parse().unwrap_or(0);
                                }
                                "num_bytes" => {
                                    num_bytes = v.parse().unwrap_or(0);
                                }
                                _ => {}
                            }
                        }
                        i += 1;
                    }
                    splits.push(HfSplitInfo {
                        name: name_part,
                        num_rows,
                        num_bytes,
                    });
                } else {
                    i += 1;
                }
            }
            // Only the first top-level `splits:` block is considered.
            return splits;
        }
        i += 1;
    }
    splits
}
294
295// ─────────────────────────────────────────────────────────────────────────────
296// Extract YAML frontmatter
297// ─────────────────────────────────────────────────────────────────────────────
298
/// Extract the content between the first two `---` markers (YAML frontmatter).
///
/// Returns `None` if no frontmatter markers are present, in which case the
/// caller should treat the entire input as raw YAML.
fn extract_frontmatter(input: &str) -> Option<&str> {
    // Leading whitespace before the opening marker is tolerated.
    let doc = input.trim_start();
    let after_marker = doc.strip_prefix("---")?;
    // Drop the remainder of the opening "---" line.
    let (_, body) = after_marker.split_once('\n')?;
    // The frontmatter ends at the next line beginning with "---".
    let end = body.find("\n---")?;
    Some(&body[..end])
}
317
318// ─────────────────────────────────────────────────────────────────────────────
319// Public API
320// ─────────────────────────────────────────────────────────────────────────────
321
322/// Parse a HuggingFace dataset card from a YAML string.
323///
324/// The string may be either:
325/// - Raw YAML (no `---` delimiters), or
326/// - A full README.md string with YAML frontmatter between `---` markers.
327///
328/// Only the fields relevant to `HfDatasetCard` are extracted; unknown keys are
329/// silently ignored.
330///
331/// # Errors
332///
333/// Returns `HfError::Parse` if a required structural element is malformed.
334pub fn parse_dataset_card(yaml_str: &str) -> Result<HfDatasetCard, HfError> {
335    // Prefer frontmatter if present; otherwise treat as raw YAML.
336    let yaml_body = extract_frontmatter(yaml_str).unwrap_or(yaml_str);
337
338    let pairs = parse_hf_yaml(yaml_body);
339    let mut card = HfDatasetCard::default();
340
341    for (key, values) in &pairs {
342        match key.as_str() {
343            "dataset_name" => {
344                card.dataset_name = values.first().cloned().unwrap_or_default();
345            }
346            "task_categories" => {
347                card.task_categories = values.clone();
348            }
349            "language" => {
350                card.language = values.clone();
351            }
352            "size_categories" => {
353                card.size_categories = values.clone();
354            }
355            "license" => {
356                let s = values.first().cloned().unwrap_or_default();
357                if !s.is_empty() {
358                    card.license = Some(s);
359                }
360            }
361            "pretty_name" => {
362                let s = values.first().cloned().unwrap_or_default();
363                if !s.is_empty() {
364                    card.pretty_name = Some(s);
365                }
366            }
367            _ => {}
368        }
369    }
370
371    // Parse structured splits block separately (needs nested parsing).
372    card.splits = parse_splits_from_yaml(yaml_body);
373
374    Ok(card)
375}
376
377/// Discover and parse the dataset card from a local directory.
378///
379/// Searches for `README.md` in `dir` and parses its YAML frontmatter as an
380/// `HfDatasetCard`.
381///
382/// # Errors
383///
384/// - `HfError::Io` — directory or `README.md` file is not accessible.
385/// - `HfError::Parse` — frontmatter could not be parsed.
386/// - `HfError::MissingField` — `README.md` has no YAML frontmatter.
387pub fn load_dataset_card(dir: &Path) -> Result<HfDatasetCard, HfError> {
388    let readme_path = dir.join("README.md");
389    let content = std::fs::read_to_string(&readme_path)?;
390    if extract_frontmatter(&content).is_none() {
391        return Err(HfError::MissingField("YAML frontmatter (---) in README.md"));
392    }
393    parse_dataset_card(&content)
394}
395
396/// Build an `HfDatasetCard` from basic parameters.
397///
398/// This is a convenience constructor used when converting a SciRS2 dataset to
399/// a HuggingFace-compatible card.
400///
401/// * `name` — dataset slug
402/// * `n_rows` — number of training samples
403/// * `task` — HuggingFace task category string (e.g. `"classification"`)
404pub fn to_hf_card(name: &str, n_rows: usize, task: &str) -> HfDatasetCard {
405    let size_cat = size_category(n_rows);
406    HfDatasetCard {
407        dataset_name: name.to_owned(),
408        task_categories: vec![task.to_owned()],
409        language: vec!["en".to_owned()],
410        size_categories: vec![size_cat],
411        license: None,
412        pretty_name: Some(name.to_owned()),
413        splits: vec![HfSplitInfo {
414            name: "train".to_owned(),
415            num_rows: n_rows,
416            num_bytes: n_rows * 64, // rough estimate
417        }],
418    }
419}
420
421/// Render an `HfDatasetCard` as minimal HuggingFace `README.md` content.
422///
423/// The output has YAML frontmatter delimited by `---` markers followed by a
424/// brief Markdown body.
425pub fn card_to_readme(card: &HfDatasetCard) -> String {
426    let mut out = String::from("---\n");
427
428    out.push_str(&format!("dataset_name: {}\n", yaml_str(&card.dataset_name)));
429
430    if !card.task_categories.is_empty() {
431        out.push_str("task_categories:\n");
432        for tc in &card.task_categories {
433            out.push_str(&format!("  - {}\n", yaml_str(tc)));
434        }
435    }
436
437    if !card.language.is_empty() {
438        out.push_str("language:\n");
439        for lang in &card.language {
440            out.push_str(&format!("  - {}\n", yaml_str(lang)));
441        }
442    }
443
444    if !card.size_categories.is_empty() {
445        out.push_str("size_categories:\n");
446        for sc in &card.size_categories {
447            out.push_str(&format!("  - {}\n", yaml_str(sc)));
448        }
449    }
450
451    if let Some(ref lic) = card.license {
452        out.push_str(&format!("license: {}\n", yaml_str(lic)));
453    }
454
455    if let Some(ref pn) = card.pretty_name {
456        out.push_str(&format!("pretty_name: {}\n", yaml_str(pn)));
457    }
458
459    if !card.splits.is_empty() {
460        out.push_str("splits:\n");
461        for split in &card.splits {
462            out.push_str(&format!(
463                "  - name: {}\n    num_rows: {}\n    num_bytes: {}\n",
464                yaml_str(&split.name),
465                split.num_rows,
466                split.num_bytes,
467            ));
468        }
469    }
470
471    out.push_str("---\n\n");
472    out.push_str(&format!("# {}\n\n", card.dataset_name));
473
474    if let Some(ref pn) = card.pretty_name {
475        out.push_str(&format!("{}\n\n", pn));
476    }
477
478    if !card.task_categories.is_empty() {
479        out.push_str(&format!("Tasks: {}\n", card.task_categories.join(", ")));
480    }
481
482    out
483}
484
485// ─────────────────────────────────────────────────────────────────────────────
486// Helpers
487// ─────────────────────────────────────────────────────────────────────────────
488
/// Return the HuggingFace size category tag for a number of rows.
///
/// Tags follow the Hub's `size_categories` vocabulary (`n<1K`, `1K<n<10K`,
/// …, `n>1T`). Previously everything >= 10M was lumped into `10M<n<100M`;
/// the larger buckets are now reported correctly. Matching is done on `u64`
/// so the large literals compile on 32-bit targets too.
fn size_category(n: usize) -> String {
    let tag = match n as u64 {
        0..=999 => "n<1K",
        1_000..=9_999 => "1K<n<10K",
        10_000..=99_999 => "10K<n<100K",
        100_000..=999_999 => "100K<n<1M",
        1_000_000..=9_999_999 => "1M<n<10M",
        10_000_000..=99_999_999 => "10M<n<100M",
        100_000_000..=999_999_999 => "100M<n<1B",
        1_000_000_000..=9_999_999_999 => "1B<n<10B",
        10_000_000_000..=99_999_999_999 => "10B<n<100B",
        100_000_000_000..=999_999_999_999 => "100B<n<1T",
        _ => "n>1T",
    };
    tag.to_owned()
}
500
/// Quote a YAML scalar when it contains characters that would otherwise be
/// misread (`:`, `#`, or quote characters); plain strings pass through as-is.
fn yaml_str(s: &str) -> String {
    let needs_quotes = s.chars().any(|c| matches!(c, ':' | '#' | '"' | '\''));
    if !needs_quotes {
        return s.to_owned();
    }
    // Wrap in double quotes, escaping any embedded double quote.
    format!("\"{}\"", s.replace('"', "\\\""))
}
509
510// ─────────────────────────────────────────────────────────────────────────────
511// Tests
512// ─────────────────────────────────────────────────────────────────────────────
513
// Unit tests for card parsing, generation, and the small YAML helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Sample YAML string covering all key fields.
    const SAMPLE_YAML: &str = "\
dataset_name: squad
task_categories:
  - question-answering
language:
  - en
size_categories:
  - 100K<n<1M
license: cc-by-4.0
pretty_name: Stanford Question Answering Dataset
splits:
  - name: train
    num_rows: 87599
    num_bytes: 29344551
  - name: validation
    num_rows: 10570
    num_bytes: 3519936
";

    // 1. parse_dataset_card parses a sample YAML string correctly
    #[test]
    fn test_parse_dataset_card_basic() {
        let card = parse_dataset_card(SAMPLE_YAML).expect("should parse");
        assert_eq!(card.dataset_name, "squad");
        assert_eq!(card.task_categories, vec!["question-answering"]);
        assert_eq!(card.language, vec!["en"]);
        assert_eq!(card.size_categories, vec!["100K<n<1M"]);
        assert_eq!(card.license, Some("cc-by-4.0".to_owned()));
        assert_eq!(
            card.pretty_name,
            Some("Stanford Question Answering Dataset".to_owned())
        );
    }

    // 2. parse_dataset_card parses splits correctly
    #[test]
    fn test_parse_splits() {
        let card = parse_dataset_card(SAMPLE_YAML).expect("should parse");
        assert_eq!(card.splits.len(), 2);
        assert_eq!(card.splits[0].name, "train");
        assert_eq!(card.splits[0].num_rows, 87599);
        assert_eq!(card.splits[0].num_bytes, 29344551);
        assert_eq!(card.splits[1].name, "validation");
        assert_eq!(card.splits[1].num_rows, 10570);
    }

    // 3. to_hf_card creates a card with correct n_rows
    #[test]
    fn test_to_hf_card_n_rows() {
        let card = to_hf_card("my-ds", 5000, "classification");
        assert_eq!(card.dataset_name, "my-ds");
        assert_eq!(card.task_categories, vec!["classification"]);
        assert!(!card.splits.is_empty());
        let train_split = card.splits.iter().find(|s| s.name == "train");
        assert!(train_split.is_some(), "should have a train split");
        assert_eq!(train_split.expect("verified above").num_rows, 5000);
    }

    // 4. card_to_readme contains the dataset name
    #[test]
    fn test_card_to_readme_contains_name() {
        let card = to_hf_card("awesome-dataset", 100, "text-classification");
        let readme = card_to_readme(&card);
        assert!(
            readme.contains("awesome-dataset"),
            "README should contain the dataset name"
        );
    }

    // 5. load_dataset_card returns Err for non-existent directory
    #[test]
    fn test_load_dataset_card_nonexistent() {
        let result = load_dataset_card(Path::new("/nonexistent/path/that/does/not/exist"));
        assert!(result.is_err(), "should fail for non-existent path");
    }

    // 6. card_to_readme -> parse_dataset_card round-trip preserves dataset_name
    #[test]
    fn test_roundtrip_dataset_name() {
        let original = to_hf_card("roundtrip-test", 2000, "regression");
        let readme = card_to_readme(&original);
        let parsed = parse_dataset_card(&readme).expect("round-trip parse should succeed");
        assert_eq!(
            parsed.dataset_name, original.dataset_name,
            "dataset_name should survive round-trip"
        );
    }

    // 7. load_dataset_card reads a real README.md from a temp directory
    // NOTE(review): the temp-dir name is fixed — two concurrent runs of this
    // test could collide; consider a per-process unique suffix.
    #[test]
    fn test_load_dataset_card_from_temp_dir() {
        let tmp_dir = std::env::temp_dir().join("scirs2_hf_test_load_card");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");

        let yaml_fm = "---\ndataset_name: temp-dataset\ntask_categories:\n  - classification\nlanguage:\n  - en\n---\n# temp-dataset\n";
        let readme_path = tmp_dir.join("README.md");
        let mut f = std::fs::File::create(&readme_path).expect("create README.md");
        f.write_all(yaml_fm.as_bytes()).expect("write");

        let card = load_dataset_card(&tmp_dir).expect("load card");
        assert_eq!(card.dataset_name, "temp-dataset");
        assert_eq!(card.task_categories, vec!["classification"]);

        // Cleanup
        let _ = std::fs::remove_file(&readme_path);
        let _ = std::fs::remove_dir(&tmp_dir);
    }

    // 8. load_dataset_card returns MissingField error when README has no frontmatter
    #[test]
    fn test_load_dataset_card_no_frontmatter() {
        let tmp_dir = std::env::temp_dir().join("scirs2_hf_test_no_fm");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");

        let readme_path = tmp_dir.join("README.md");
        let mut f = std::fs::File::create(&readme_path).expect("create README.md");
        f.write_all(b"# Plain README\n\nNo frontmatter here.\n")
            .expect("write");

        let result = load_dataset_card(&tmp_dir);
        assert!(
            matches!(result, Err(HfError::MissingField(_))),
            "expected MissingField, got: {:?}",
            result
        );

        let _ = std::fs::remove_file(&readme_path);
        let _ = std::fs::remove_dir(&tmp_dir);
    }

    // 9. size_category helper returns expected values
    #[test]
    fn test_size_categories() {
        assert_eq!(size_category(500), "n<1K");
        assert_eq!(size_category(5000), "1K<n<10K");
        assert_eq!(size_category(50_000), "10K<n<100K");
        assert_eq!(size_category(500_000), "100K<n<1M");
        assert_eq!(size_category(5_000_000), "1M<n<10M");
        assert_eq!(size_category(50_000_000), "10M<n<100M");
    }

    // 10. parse_dataset_card handles inline list syntax
    #[test]
    fn test_parse_inline_list() {
        let yaml = "dataset_name: inline-test\nlanguage: [en, fr, de]\n";
        let card = parse_dataset_card(yaml).expect("parse");
        assert_eq!(card.dataset_name, "inline-test");
        assert_eq!(card.language, vec!["en", "fr", "de"]);
    }
}