1use crate::error::{DatasetsError, Result};
19use std::collections::HashMap;
20
/// Machine-learning task a dataset is intended for.
///
/// Named tasks are identified by the kebab-case strings returned by
/// [`TaskCategory::as_str`] (for example `text-classification`); any other
/// identifier is carried verbatim in [`TaskCategory::Other`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum TaskCategory {
    /// `text-classification`
    TextClassification,
    /// `token-classification`
    TokenClassification,
    /// `question-answering`
    QuestionAnswering,
    /// `summarization`
    Summarization,
    /// `translation`
    Translation,
    /// `text-generation`
    TextGeneration,
    /// `image-classification`
    ImageClassification,
    /// `object-detection`
    ObjectDetection,
    /// An identifier without a dedicated variant (`from_str` stores it trimmed).
    Other(String),
}

impl TaskCategory {
    /// Every named (non-`Other`) category. `from_str` matches against their
    /// `as_str` identifiers, so the two conversions share one source of truth.
    const KNOWN: [TaskCategory; 8] = [
        Self::TextClassification,
        Self::TokenClassification,
        Self::QuestionAnswering,
        Self::Summarization,
        Self::Translation,
        Self::TextGeneration,
        Self::ImageClassification,
        Self::ObjectDetection,
    ];

    /// Canonical identifier of this category; for `Other` it is the stored text.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Other(custom) => custom.as_str(),
            Self::TextClassification => "text-classification",
            Self::TokenClassification => "token-classification",
            Self::QuestionAnswering => "question-answering",
            Self::Summarization => "summarization",
            Self::Translation => "translation",
            Self::TextGeneration => "text-generation",
            Self::ImageClassification => "image-classification",
            Self::ObjectDetection => "object-detection",
        }
    }

    /// Map an identifier (surrounding whitespace ignored, case-sensitive) to
    /// its category, falling back to `Other` with the trimmed text.
    // The conversion is infallible; an inherent fn avoids the `Result` and
    // the trait import that `std::str::FromStr` would impose on callers.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        let key = s.trim();
        Self::KNOWN
            .iter()
            .find(|known| known.as_str() == key)
            .cloned()
            .unwrap_or_else(|| Self::Other(key.to_owned()))
    }
}
81
/// Size information recorded for one named split of a dataset.
///
/// Read from entries of the `splits` list by `parse_dataset_card` and
/// rendered back by `write_dataset_card`.
#[derive(Debug, Clone, PartialEq)]
pub struct SplitInfo {
    /// Split name (e.g. `train`); list entries without a string `name` are
    /// dropped while parsing.
    pub name: String,
    /// The split's `num_bytes` value; 0 when absent or not a non-negative integer.
    pub num_bytes: u64,
    /// The split's `num_examples` value; 0 when absent or not a non-negative
    /// integer (`validate_card` warns about 0).
    pub num_examples: u64,
}
92
/// Storage type of a dataset feature (column).
///
/// The canonical spelling of each variant is what [`FeatureDtype::as_str`]
/// returns.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum FeatureDtype {
    /// `string`
    String,
    /// `int32`
    Int32,
    /// `int64`
    Int64,
    /// `float32`
    Float32,
    /// `float64`
    Float64,
    /// `bool`
    Bool,
    /// `image`
    Image,
    /// `audio`
    Audio,
}

impl FeatureDtype {
    /// Every variant, used to resolve canonical names in `from_str`.
    const ALL: [FeatureDtype; 8] = [
        Self::String,
        Self::Int32,
        Self::Int64,
        Self::Float32,
        Self::Float64,
        Self::Bool,
        Self::Image,
        Self::Audio,
    ];

    /// Canonical lowercase name of this dtype.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Bool => "bool",
            Self::Int32 => "int32",
            Self::Int64 => "int64",
            Self::Float32 => "float32",
            Self::Float64 => "float64",
            Self::Image => "image",
            Self::Audio => "audio",
            Self::String => "string",
        }
    }

    /// Parse a dtype name, ignoring surrounding whitespace.
    ///
    /// Canonical names map to their variant. The aliases `int` (Int64),
    /// `float` (Float32), `double` (Float64) and `boolean` (Bool) are also
    /// accepted. `str` and any unrecognised name fall back to `String`.
    // Infallible by design (unknown names degrade to `String`), hence an
    // inherent fn instead of `std::str::FromStr`.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Self {
        let key = s.trim();
        if let Some(exact) = Self::ALL.iter().find(|dtype| dtype.as_str() == key) {
            return exact.clone();
        }
        match key {
            "int" => Self::Int64,
            "float" => Self::Float32,
            "double" => Self::Float64,
            "boolean" => Self::Bool,
            // `str` and every unknown spelling are treated as plain strings.
            _ => Self::String,
        }
    }
}
146
/// Description of one column (feature) of a dataset.
///
/// Read from entries of the `features` list by `parse_dataset_card` and
/// rendered back by `write_dataset_card`.
#[derive(Debug, Clone, PartialEq)]
pub struct FeatureInfo {
    /// Column name; list entries without a string `name` are dropped while parsing.
    pub name: String,
    /// Column type, from the entry's `dtype` via `FeatureDtype::from_str`
    /// (`string` when the key is missing or not a string).
    pub dtype: FeatureDtype,
    /// Optional free-text `description` of the column.
    pub description: Option<String>,
}
157
/// Metadata held in a dataset card's YAML front matter.
///
/// `parse_dataset_card` fills each field from the top-level key of the same
/// name and `write_dataset_card` renders the struct back to front matter.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct DatasetCard {
    /// Dataset name (`name`); empty when absent or not a string.
    pub name: String,
    /// Free-text description (`description`); empty when absent or not a string.
    pub description: String,
    /// License identifier (`license`); `None` when absent, empty or not a string.
    pub license: Option<String>,
    /// Language codes (`language`), taken from a list or a single string.
    pub language: Vec<String>,
    /// Free-form tags (`tags`), taken from a list or a single string.
    pub tags: Vec<String>,
    /// Task categories (`task_categories`).
    pub task_categories: Vec<TaskCategory>,
    /// Per-split size information (`splits`).
    pub splits: Vec<SplitInfo>,
    /// Column descriptions (`features`).
    pub features: Vec<FeatureInfo>,
    /// Citation text (`citation`); `None` when absent, empty or not a string.
    pub citation: Option<String>,
}
180
/// Minimal YAML value model produced by [`simple_yaml_parse`].
#[derive(Debug, Clone, PartialEq)]
pub enum YamlValue {
    /// A plain or quoted string scalar.
    Str(String),
    /// An integer scalar that fits in `i64`.
    Int(i64),
    /// A floating-point scalar.
    Float(f64),
    /// A sequence (block `- item` list or flow `[a, b]` list).
    List(Vec<YamlValue>),
    /// A mapping from string keys to values.
    Map(HashMap<String, YamlValue>),
    /// `true` or `false`.
    Bool(bool),
    /// `null`, `~`, or an empty value.
    Null,
}

impl YamlValue {
    /// Borrow the text of a `Str`; `None` for every other variant.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            YamlValue::Str(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// The payload of an `Int`; `None` for every other variant.
    pub fn as_i64(&self) -> Option<i64> {
        match self {
            YamlValue::Int(n) => Some(*n),
            _ => None,
        }
    }

    /// A non-negative integer from an `Int`, or from a `Str` holding a
    /// decimal `u64` (surrounding whitespace allowed).
    ///
    /// Negative integers yield `None`: an `as` cast would silently wrap them
    /// to huge values (e.g. `-1` → `u64::MAX`).
    pub fn as_u64(&self) -> Option<u64> {
        match self {
            YamlValue::Int(n) => u64::try_from(*n).ok(),
            YamlValue::Str(s) => s.trim().parse().ok(),
            _ => None,
        }
    }

    /// Borrow the items of a `List`; `None` for every other variant.
    pub fn as_list(&self) -> Option<&Vec<YamlValue>> {
        match self {
            YamlValue::List(items) => Some(items),
            _ => None,
        }
    }

    /// Borrow the entries of a `Map`; `None` for every other variant.
    pub fn as_map(&self) -> Option<&HashMap<String, YamlValue>> {
        match self {
            YamlValue::Map(entries) => Some(entries),
            _ => None,
        }
    }
}
250
251pub fn simple_yaml_parse(s: &str) -> HashMap<String, YamlValue> {
263 let mut result: HashMap<String, YamlValue> = HashMap::new();
264 let lines: Vec<&str> = s.lines().collect();
265 let mut i = 0usize;
266
267 while i < lines.len() {
268 let line = lines[i];
269 let trimmed = line.trim();
271 if trimmed.is_empty() || trimmed.starts_with('#') {
272 i += 1;
273 continue;
274 }
275
276 if let Some(colon_pos) = find_colon(line) {
278 let indent = leading_spaces(line);
279 if indent == 0 {
280 let key = line[..colon_pos].trim().to_owned();
281 let rest = line[colon_pos + 1..].trim();
282
283 if rest.is_empty() {
284 i += 1;
286 let mut block_items: Vec<YamlValue> = Vec::new();
288 let mut sub_map: HashMap<String, YamlValue> = HashMap::new();
289 let mut is_list = false;
290 let mut is_map = false;
291
292 while i < lines.len() {
293 let sub_line = lines[i];
294 let sub_trimmed = sub_line.trim();
295 let sub_indent = leading_spaces(sub_line);
296
297 if sub_trimmed.is_empty() || sub_trimmed.starts_with('#') {
298 i += 1;
299 continue;
300 }
301 if sub_indent == 0 {
303 break;
304 }
305 if sub_trimmed.starts_with("- ") || sub_trimmed == "-" {
307 is_list = true;
308 let item_str = if sub_trimmed.len() > 2 {
309 sub_trimmed[2..].trim()
310 } else {
311 ""
312 };
313 block_items.push(parse_scalar(item_str));
314 i += 1;
315 } else if let Some(sub_colon) = find_colon(sub_trimmed) {
316 is_map = true;
317 let sub_key = sub_trimmed[..sub_colon].trim().to_owned();
318 let sub_val = sub_trimmed[sub_colon + 1..].trim();
319 sub_map.insert(sub_key, parse_scalar(sub_val));
320 i += 1;
321 } else {
322 i += 1;
323 }
324 }
325
326 let value = if is_list {
327 YamlValue::List(block_items)
328 } else if is_map {
329 YamlValue::Map(sub_map)
330 } else {
331 YamlValue::Null
332 };
333 result.insert(key, value);
334 } else if rest.starts_with('[') && rest.ends_with(']') {
335 let inner = &rest[1..rest.len() - 1];
337 let items: Vec<YamlValue> =
338 inner.split(',').map(|s| parse_scalar(s.trim())).collect();
339 result.insert(key, YamlValue::List(items));
340 i += 1;
341 } else {
342 result.insert(key, parse_scalar(rest));
343 i += 1;
344 }
345 continue;
346 }
347 }
348 i += 1;
349 }
350
351 result
352}
353
/// Byte length of the leading whitespace of `s` (any `char::is_whitespace`
/// character counts, not only spaces). Used as the indentation column.
fn leading_spaces(s: &str) -> usize {
    s.find(|c: char| !c.is_whitespace()).unwrap_or(s.len())
}
357
/// Byte index of the first `:` in `s`, or `None` when there is none.
///
/// No quote or escape awareness: the very first colon wins.
fn find_colon(s: &str) -> Option<usize> {
    s.bytes().position(|byte| byte == b':')
}
368
369fn parse_scalar(s: &str) -> YamlValue {
370 let s = s.trim();
371 if s.is_empty() || s == "null" || s == "~" {
372 return YamlValue::Null;
373 }
374 if s == "true" {
375 return YamlValue::Bool(true);
376 }
377 if s == "false" {
378 return YamlValue::Bool(false);
379 }
380 let unquoted = strip_quotes(s);
382 if let Ok(n) = unquoted.parse::<i64>() {
384 return YamlValue::Int(n);
385 }
386 if let Ok(f) = unquoted.parse::<f64>() {
388 return YamlValue::Float(f);
389 }
390 YamlValue::Str(unquoted.to_owned())
391}
392
/// Remove one pair of matching surrounding quotes (`"..."` or `'...'`).
///
/// Any other input, including a lone quote character or mismatched quotes,
/// is returned unchanged. The length check matters: a single `"` both starts
/// and ends with `"`, and slicing `[1..0]` would panic.
fn strip_quotes(s: &str) -> &str {
    let wrapped_in = |q: char| s.len() >= 2 && s.starts_with(q) && s.ends_with(q);
    if wrapped_in('"') || wrapped_in('\'') {
        // Both quote characters are one byte, so these are char boundaries.
        &s[1..s.len() - 1]
    } else {
        s
    }
}
400
/// Return the YAML front matter of a Markdown document.
///
/// Front matter is the text between an opening line that is exactly `---`
/// and the next line that is exactly `---`. The opening line must be the
/// first non-blank line and may be preceded by a UTF-8 BOM. Trailing
/// whitespace and `\r` on delimiter lines are ignored. If the closing
/// delimiter is missing, everything after the opening line is returned.
///
/// Returns `None` when the document does not open with `---`. Only whole
/// delimiter lines count, so a `---` inside a value or a Markdown horizontal
/// rule later in the body is not mistaken for front matter.
fn extract_front_matter(input: &str) -> Option<&str> {
    let mut offset = 0usize;
    let mut body_start: Option<usize> = None;

    for segment in input.split_inclusive('\n') {
        let next_offset = offset + segment.len();
        let line = segment.trim_end().trim_start_matches('\u{feff}');
        let is_delimiter = line == "---";
        match body_start {
            None if is_delimiter => body_start = Some(next_offset),
            None if line.trim().is_empty() => {}
            None => return None,
            Some(start) if is_delimiter => return Some(&input[start..offset]),
            Some(_) => {}
        }
        offset = next_offset;
    }

    body_start.map(|start| &input[start..])
}
413
414pub fn parse_dataset_card(input: &str) -> Result<DatasetCard> {
424 let yaml_src = match extract_front_matter(input) {
425 Some(fm) => fm,
426 None => input,
427 };
428
429 let map = simple_yaml_parse(yaml_src);
430 let mut card = DatasetCard::default();
431
432 if let Some(v) = map.get("name") {
433 card.name = v.as_str().unwrap_or_default().to_owned();
434 }
435 if let Some(v) = map.get("description") {
436 card.description = v.as_str().unwrap_or_default().to_owned();
437 }
438 if let Some(v) = map.get("license") {
439 let s = v.as_str().unwrap_or_default().to_owned();
440 if !s.is_empty() {
441 card.license = Some(s);
442 }
443 }
444 if let Some(v) = map.get("language") {
445 card.language = collect_str_list(v);
446 }
447 if let Some(v) = map.get("tags") {
448 card.tags = collect_str_list(v);
449 }
450 if let Some(v) = map.get("task_categories") {
451 card.task_categories = collect_str_list(v)
452 .into_iter()
453 .map(|s| TaskCategory::from_str(&s))
454 .collect();
455 }
456 if let Some(v) = map.get("splits") {
457 card.splits = parse_splits(v);
458 }
459 if let Some(v) = map.get("features") {
460 card.features = parse_features(v);
461 }
462 if let Some(v) = map.get("citation") {
463 let s = v.as_str().unwrap_or_default().to_owned();
464 if !s.is_empty() {
465 card.citation = Some(s);
466 }
467 }
468
469 Ok(card)
470}
471
472fn collect_str_list(v: &YamlValue) -> Vec<String> {
474 match v {
475 YamlValue::List(items) => items
476 .iter()
477 .filter_map(|item| item.as_str().map(str::to_owned))
478 .collect(),
479 YamlValue::Str(s) => vec![s.clone()],
480 _ => Vec::new(),
481 }
482}
483
484fn parse_splits(v: &YamlValue) -> Vec<SplitInfo> {
486 let items = match v.as_list() {
487 Some(l) => l,
488 None => return Vec::new(),
489 };
490 items
491 .iter()
492 .filter_map(|item| {
493 let m = item.as_map()?;
494 let name = m.get("name")?.as_str()?.to_owned();
495 let num_bytes = m.get("num_bytes").and_then(|v| v.as_u64()).unwrap_or(0);
496 let num_examples = m.get("num_examples").and_then(|v| v.as_u64()).unwrap_or(0);
497 Some(SplitInfo {
498 name,
499 num_bytes,
500 num_examples,
501 })
502 })
503 .collect()
504}
505
506fn parse_features(v: &YamlValue) -> Vec<FeatureInfo> {
508 let items = match v.as_list() {
509 Some(l) => l,
510 None => return Vec::new(),
511 };
512 items
513 .iter()
514 .filter_map(|item| {
515 let m = item.as_map()?;
516 let name = m.get("name")?.as_str()?.to_owned();
517 let dtype_str = m.get("dtype").and_then(|v| v.as_str()).unwrap_or("string");
518 let dtype = FeatureDtype::from_str(dtype_str);
519 let description = m
520 .get("description")
521 .and_then(|v| v.as_str())
522 .map(str::to_owned);
523 Some(FeatureInfo {
524 name,
525 dtype,
526 description,
527 })
528 })
529 .collect()
530}
531
532pub fn write_dataset_card(card: &DatasetCard) -> String {
534 let mut out = String::from("---\n");
535
536 out.push_str(&format!("name: {}\n", yaml_escape(&card.name)));
537 out.push_str(&format!(
538 "description: {}\n",
539 yaml_escape(&card.description)
540 ));
541
542 if let Some(ref lic) = card.license {
543 out.push_str(&format!("license: {}\n", yaml_escape(lic)));
544 }
545
546 if !card.language.is_empty() {
547 out.push_str("language:\n");
548 for lang in &card.language {
549 out.push_str(&format!(" - {}\n", yaml_escape(lang)));
550 }
551 }
552
553 if !card.tags.is_empty() {
554 out.push_str("tags:\n");
555 for tag in &card.tags {
556 out.push_str(&format!(" - {}\n", yaml_escape(tag)));
557 }
558 }
559
560 if !card.task_categories.is_empty() {
561 out.push_str("task_categories:\n");
562 for tc in &card.task_categories {
563 out.push_str(&format!(" - {}\n", yaml_escape(tc.as_str())));
564 }
565 }
566
567 if !card.splits.is_empty() {
568 out.push_str("splits:\n");
569 for split in &card.splits {
570 out.push_str(&format!(
571 " - name: {}\n num_bytes: {}\n num_examples: {}\n",
572 yaml_escape(&split.name),
573 split.num_bytes,
574 split.num_examples
575 ));
576 }
577 }
578
579 if !card.features.is_empty() {
580 out.push_str("features:\n");
581 for feat in &card.features {
582 out.push_str(&format!(
583 " - name: {}\n dtype: {}\n",
584 yaml_escape(&feat.name),
585 feat.dtype.as_str()
586 ));
587 if let Some(ref desc) = feat.description {
588 out.push_str(&format!(" description: {}\n", yaml_escape(desc)));
589 }
590 }
591 }
592
593 if let Some(ref cit) = card.citation {
594 out.push_str(&format!("citation: {}\n", yaml_escape(cit)));
595 }
596
597 out.push_str("---\n");
598 out
599}
600
/// Render `s` as a YAML scalar that reads back as exactly the same string.
///
/// Plain text is emitted as-is. Anything that could be misread is wrapped in
/// double quotes, with `"`, `\` and control characters escaped (`\n`, `\r`,
/// `\t`, otherwise `\uXXXX`). Misreadable text includes an empty string,
/// text YAML would take as a number, boolean or null (including YAML 1.1's
/// `yes`/`no`/`on`/`off`), text starting with an indicator such as `-` or
/// `[`, and text that has edge whitespace, newlines, `:`, `#`, quotes or
/// backslashes.
fn yaml_escape(s: &str) -> String {
    if !yaml_needs_quotes(s) {
        return s.to_owned();
    }
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for c in s.chars() {
        match c {
            '"' => quoted.push_str("\\\""),
            '\\' => quoted.push_str("\\\\"),
            '\n' => quoted.push_str("\\n"),
            '\r' => quoted.push_str("\\r"),
            '\t' => quoted.push_str("\\t"),
            c if c.is_control() => quoted.push_str(&format!("\\u{:04X}", u32::from(c))),
            c => quoted.push(c),
        }
    }
    quoted.push('"');
    quoted
}

/// Whether `s` must be double-quoted to survive a YAML round trip.
fn yaml_needs_quotes(s: &str) -> bool {
    // Case-insensitive words YAML 1.1/1.2 parsers turn into bools or null.
    const RESERVED: &[&str] = &[
        "null", "~", "true", "false", "yes", "no", "on", "off", "y", "n",
    ];
    // Characters with structural meaning at the start of a plain scalar.
    const LEADING_INDICATORS: &str = "-?[]{},&*!|>%@`";

    let first = match s.chars().next() {
        Some(c) => c,
        None => return true,
    };
    s.trim() != s
        || LEADING_INDICATORS.contains(first)
        || s
            .chars()
            .any(|c| matches!(c, ':' | '#' | '\'' | '"' | '\\') || c.is_control())
        || RESERVED.iter().any(|word| s.eq_ignore_ascii_case(word))
        || s.parse::<f64>().is_ok()
}
609
610pub fn validate_card(card: &DatasetCard) -> Vec<String> {
614 let mut warnings = Vec::new();
615
616 if card.name.trim().is_empty() {
617 warnings.push("'name' is empty".to_owned());
618 }
619 if card.description.trim().is_empty() {
620 warnings.push("'description' is empty".to_owned());
621 }
622 if card.language.is_empty() {
623 warnings.push(
624 "'language' list is empty; consider specifying at least one language code".to_owned(),
625 );
626 }
627 if card.task_categories.is_empty() {
628 warnings.push("'task_categories' is empty; consider specifying the task type".to_owned());
629 }
630 if card.splits.is_empty() {
631 warnings.push("'splits' is empty; consider documenting train/test splits".to_owned());
632 }
633 for split in &card.splits {
634 if split.name.trim().is_empty() {
635 warnings.push("A split has an empty 'name'".to_owned());
636 }
637 if split.num_examples == 0 {
638 warnings.push(format!("Split '{}' has num_examples == 0", split.name));
639 }
640 }
641 for feat in &card.features {
642 if feat.name.trim().is_empty() {
643 warnings.push("A feature has an empty 'name'".to_owned());
644 }
645 }
646
647 warnings
648}
649
#[cfg(test)]
mod tests {
    //! Unit tests for card parsing, rendering, validation and the YAML helpers.
    use super::*;

    /// Smallest useful card: one scalar key and a one-element block list.
    const SIMPLE_CARD: &str = "---\nname: test\nlanguage:\n - en\n---\n";

    #[test]
    fn test_parse_simple_card() {
        let parsed = parse_dataset_card(SIMPLE_CARD).expect("should parse");
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.language, vec!["en".to_owned()]);
    }

    #[test]
    fn test_parse_full_card() {
        let source = concat!(
            "---\n",
            "name: my-dataset\n",
            "description: A comprehensive dataset\n",
            "license: apache-2.0\n",
            "language:\n - en\n - fr\n",
            "tags:\n - nlp\n - benchmark\n",
            "task_categories:\n - text-classification\n",
            "---\n",
        );
        let parsed = parse_dataset_card(source).expect("parse");
        assert_eq!(parsed.name, "my-dataset");
        assert_eq!(parsed.description, "A comprehensive dataset");
        assert_eq!(parsed.license, Some("apache-2.0".into()));
        assert_eq!(parsed.language, vec!["en", "fr"]);
        assert_eq!(parsed.tags, vec!["nlp", "benchmark"]);
        assert_eq!(
            parsed.task_categories,
            vec![TaskCategory::TextClassification]
        );
    }

    #[test]
    fn test_write_roundtrip() {
        let original = DatasetCard {
            name: "roundtrip-test".into(),
            description: "Test description".into(),
            license: Some("mit".into()),
            language: vec!["en".into()],
            tags: vec!["test".into()],
            task_categories: vec![TaskCategory::Summarization],
            splits: vec![SplitInfo {
                name: "train".into(),
                num_bytes: 1024,
                num_examples: 100,
            }],
            features: vec![FeatureInfo {
                name: "text".into(),
                dtype: FeatureDtype::String,
                description: None,
            }],
            citation: None,
        };

        let rendered = write_dataset_card(&original);
        assert!(rendered.contains("name: roundtrip-test"));
        assert!(rendered.contains("license: mit"));
        assert!(rendered.contains("num_examples: 100"));

        // Only these fields are compared after re-parsing the rendered text.
        let reparsed = parse_dataset_card(&rendered).expect("reparse");
        assert_eq!(reparsed.name, original.name);
        assert_eq!(reparsed.license, original.license);
        assert_eq!(reparsed.language, original.language);
    }

    #[test]
    fn test_validate_empty_name() {
        let nameless = DatasetCard {
            name: String::new(),
            ..Default::default()
        };
        let warnings = validate_card(&nameless);
        assert!(
            warnings.iter().any(|w| w.contains("'name' is empty")),
            "expected warning about empty name, got: {warnings:?}"
        );
    }

    #[test]
    fn test_validate_valid_card() {
        let complete = DatasetCard {
            name: "good".into(),
            description: "good desc".into(),
            language: vec!["en".into()],
            task_categories: vec![TaskCategory::TextClassification],
            splits: vec![SplitInfo {
                name: "train".into(),
                num_bytes: 100,
                num_examples: 10,
            }],
            ..Default::default()
        };
        let warnings = validate_card(&complete);
        assert!(warnings.is_empty(), "unexpected warnings: {warnings:?}");
    }

    #[test]
    fn test_task_category_roundtrip() {
        let categories = [
            TaskCategory::TextClassification,
            TaskCategory::TokenClassification,
            TaskCategory::QuestionAnswering,
            TaskCategory::Summarization,
            TaskCategory::Translation,
            TaskCategory::TextGeneration,
            TaskCategory::ImageClassification,
            TaskCategory::ObjectDetection,
            TaskCategory::Other("custom-task".into()),
        ];
        for cat in &categories {
            let s = cat.as_str();
            assert_eq!(&TaskCategory::from_str(s), cat, "roundtrip failed for {s}");
        }
    }

    #[test]
    fn test_feature_dtype_roundtrip() {
        let dtypes = [
            FeatureDtype::String,
            FeatureDtype::Int32,
            FeatureDtype::Int64,
            FeatureDtype::Float32,
            FeatureDtype::Float64,
            FeatureDtype::Bool,
            FeatureDtype::Image,
            FeatureDtype::Audio,
        ];
        for dt in &dtypes {
            assert_eq!(&FeatureDtype::from_str(dt.as_str()), dt);
        }
    }

    #[test]
    fn test_inline_list_parsing() {
        match simple_yaml_parse("tags: [nlp, vision, audio]\n").get("tags") {
            Some(YamlValue::List(items)) => assert_eq!(items.len(), 3),
            _ => panic!("expected list"),
        }
    }

    #[test]
    fn test_yaml_scalar_types() {
        let parsed = simple_yaml_parse("count: 42\nrate: 3.14\nflag: true\nempty: null\n");
        assert_eq!(parsed.get("count"), Some(&YamlValue::Int(42)));
        assert!(matches!(parsed.get("rate"), Some(YamlValue::Float(_))));
        assert_eq!(parsed.get("flag"), Some(&YamlValue::Bool(true)));
        assert_eq!(parsed.get("empty"), Some(&YamlValue::Null));
    }
}