synth_ai_core/container/
datasets.rs1use crate::errors::CoreError;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct TaskDatasetSpec {
6 pub id: String,
7 pub name: String,
8 #[serde(default)]
9 pub version: Option<String>,
10 #[serde(default)]
11 pub splits: Vec<String>,
12 #[serde(default)]
13 pub default_split: Option<String>,
14 #[serde(default)]
15 pub cardinality: Option<i64>,
16 #[serde(default)]
17 pub description: Option<String>,
18}
19
20impl TaskDatasetSpec {
21 pub fn validate(&self) -> Result<(), CoreError> {
22 if let Some(default_split) = &self.default_split {
23 if !self.splits.is_empty() && !self.splits.contains(default_split) {
24 return Err(CoreError::InvalidInput(
25 "default_split must be one of splits when provided".to_string(),
26 ));
27 }
28 }
29 Ok(())
30 }
31
32 pub fn merge_with(&self, other: &TaskDatasetSpec) -> TaskDatasetSpec {
33 TaskDatasetSpec {
34 id: if other.id.is_empty() {
35 self.id.clone()
36 } else {
37 other.id.clone()
38 },
39 name: if other.name.is_empty() {
40 self.name.clone()
41 } else {
42 other.name.clone()
43 },
44 version: other.version.clone().or_else(|| self.version.clone()),
45 splits: if other.splits.is_empty() {
46 self.splits.clone()
47 } else {
48 other.splits.clone()
49 },
50 default_split: other
51 .default_split
52 .clone()
53 .or_else(|| self.default_split.clone()),
54 cardinality: other.cardinality.or(self.cardinality),
55 description: other
56 .description
57 .clone()
58 .or_else(|| self.description.clone()),
59 }
60 }
61}
62
63pub fn ensure_split(spec: &TaskDatasetSpec, split: Option<&str>) -> Result<String, CoreError> {
64 if spec.splits.is_empty() {
65 return Ok(split
66 .unwrap_or_else(|| spec.default_split.as_deref().unwrap_or("default"))
67 .to_string());
68 }
69 match split {
70 Some(value) => {
71 if spec.splits.contains(&value.to_string()) {
72 Ok(value.to_string())
73 } else {
74 Err(CoreError::InvalidInput(format!(
75 "Unknown split '{}' for dataset {}",
76 value, spec.id
77 )))
78 }
79 }
80 None => {
81 if let Some(default_split) = &spec.default_split {
82 Ok(default_split.clone())
83 } else {
84 Err(CoreError::InvalidInput(format!(
85 "split must be provided for dataset {}",
86 spec.id
87 )))
88 }
89 }
90 }
91}
92
/// Maps an arbitrary seed onto a non-negative value, optionally wrapped into
/// `0..cardinality`.
///
/// The seed's absolute value is taken first. If `cardinality` is `Some(c)`
/// with `c > 0`, the result is `|seed| % c`; a missing, zero or negative
/// cardinality leaves the absolute value unreduced.
///
/// The magnitude is computed in `u64`, so `i64::MIN` (whose absolute value
/// does not fit in `i64`) is handled without overflow: it is reduced exactly
/// when a positive cardinality is given, and saturates to `i64::MAX`
/// otherwise. The returned value is therefore never negative.
pub fn normalise_seed(seed: i64, cardinality: Option<i64>) -> i64 {
    // `unsigned_abs` cannot overflow, unlike `abs`, which panics (debug) or
    // wraps back to `i64::MIN` (release) for `i64::MIN`.
    let magnitude: u64 = seed.unsigned_abs();

    let reduced = match cardinality {
        Some(card) if card > 0 => magnitude % card.unsigned_abs(),
        _ => magnitude,
    };

    // After reduction the value is below `card <= i64::MAX`. Only the
    // unreduced `i64::MIN` case (2^63) exceeds the range, so clamp it; the
    // cast is then lossless.
    reduced.min(i64::MAX as u64) as i64
}