// synth_ai_core/container/datasets.rs

1use crate::errors::CoreError;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct TaskDatasetSpec {
6    pub id: String,
7    pub name: String,
8    #[serde(default)]
9    pub version: Option<String>,
10    #[serde(default)]
11    pub splits: Vec<String>,
12    #[serde(default)]
13    pub default_split: Option<String>,
14    #[serde(default)]
15    pub cardinality: Option<i64>,
16    #[serde(default)]
17    pub description: Option<String>,
18}
19
20impl TaskDatasetSpec {
21    pub fn validate(&self) -> Result<(), CoreError> {
22        if let Some(default_split) = &self.default_split {
23            if !self.splits.is_empty() && !self.splits.contains(default_split) {
24                return Err(CoreError::InvalidInput(
25                    "default_split must be one of splits when provided".to_string(),
26                ));
27            }
28        }
29        Ok(())
30    }
31
32    pub fn merge_with(&self, other: &TaskDatasetSpec) -> TaskDatasetSpec {
33        TaskDatasetSpec {
34            id: if other.id.is_empty() {
35                self.id.clone()
36            } else {
37                other.id.clone()
38            },
39            name: if other.name.is_empty() {
40                self.name.clone()
41            } else {
42                other.name.clone()
43            },
44            version: other.version.clone().or_else(|| self.version.clone()),
45            splits: if other.splits.is_empty() {
46                self.splits.clone()
47            } else {
48                other.splits.clone()
49            },
50            default_split: other
51                .default_split
52                .clone()
53                .or_else(|| self.default_split.clone()),
54            cardinality: other.cardinality.or(self.cardinality),
55            description: other
56                .description
57                .clone()
58                .or_else(|| self.description.clone()),
59        }
60    }
61}
62
63pub fn ensure_split(spec: &TaskDatasetSpec, split: Option<&str>) -> Result<String, CoreError> {
64    if spec.splits.is_empty() {
65        return Ok(split
66            .unwrap_or_else(|| spec.default_split.as_deref().unwrap_or("default"))
67            .to_string());
68    }
69    match split {
70        Some(value) => {
71            if spec.splits.contains(&value.to_string()) {
72                Ok(value.to_string())
73            } else {
74                Err(CoreError::InvalidInput(format!(
75                    "Unknown split '{}' for dataset {}",
76                    value, spec.id
77                )))
78            }
79        }
80        None => {
81            if let Some(default_split) = &spec.default_split {
82                Ok(default_split.clone())
83            } else {
84                Err(CoreError::InvalidInput(format!(
85                    "split must be provided for dataset {}",
86                    spec.id
87                )))
88            }
89        }
90    }
91}
92
/// Maps an arbitrary seed onto a non-negative value, optionally reduced into `0..cardinality`.
///
/// Negative seeds are replaced by their magnitude. When `cardinality` is `Some(card)` with
/// `card > 0`, the result is `|seed| % card`. Otherwise (`None`, zero, or negative) the
/// magnitude is returned unreduced.
///
/// `i64::MIN` has no positive `i64` counterpart. With a positive cardinality, its exact
/// magnitude (2^63) is reduced in unsigned arithmetic. Without one, the result saturates to
/// `i64::MAX`. Plain `abs()` would panic on `i64::MIN` in debug builds and wrap to a negative
/// value in release builds.
pub fn normalise_seed(seed: i64, cardinality: Option<i64>) -> i64 {
    match cardinality {
        Some(card) if card > 0 => {
            // Use the unsigned magnitude so `i64::MIN` cannot overflow. `card > 0`, so the
            // cast to u64 is exact. The remainder is strictly less than `card <= i64::MAX`,
            // so casting back to i64 cannot wrap.
            let magnitude = seed.unsigned_abs();
            (magnitude % card as u64) as i64
        }
        _ => seed.saturating_abs(),
    }
}