// synth_claw/config/schema.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct SynthConfig {
7    pub name: String,
8    #[serde(default)]
9    pub source: Option<SourceConfig>,
10    pub provider: ProviderConfig,
11    pub generation: GenerationConfig,
12    pub output: OutputConfig,
13    #[serde(default)]
14    pub validation: Option<ValidationConfig>,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18#[serde(tag = "type", rename_all = "lowercase")]
19pub enum SourceConfig {
20    HuggingFace {
21        dataset: String,
22        #[serde(default)]
23        subset: Option<String>,
24        #[serde(default = "default_split")]
25        split: String,
26        #[serde(default)]
27        sample: Option<usize>,
28        #[serde(default)]
29        columns: Option<Vec<String>>,
30    },
31    Local {
32        path: PathBuf,
33        format: FileFormat,
34        #[serde(default)]
35        sample: Option<usize>,
36    },
37}
38
/// Serde default for the `split` field of [`SourceConfig::HuggingFace`]:
/// returns `"train"`.
fn default_split() -> String {
    "train".to_string()
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
44#[serde(rename_all = "lowercase")]
45pub enum FileFormat {
46    Json,
47    Jsonl,
48    Csv,
49    Parquet,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type", rename_all = "lowercase")]
54pub enum ProviderConfig {
55    OpenAI {
56        model: String,
57        #[serde(default)]
58        api_key: Option<String>,
59        #[serde(default)]
60        base_url: Option<String>,
61        #[serde(default)]
62        temperature: Option<f32>,
63        #[serde(default)]
64        max_tokens: Option<u32>,
65    },
66    Anthropic {
67        model: String,
68        #[serde(default)]
69        api_key: Option<String>,
70        #[serde(default)]
71        temperature: Option<f32>,
72        #[serde(default)]
73        max_tokens: Option<u32>,
74    },
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct GenerationConfig {
79    pub task: GenerationTask,
80    #[serde(default = "default_count")]
81    pub count: usize,
82    #[serde(default)]
83    pub count_per_example: Option<usize>,
84    #[serde(default = "default_concurrency")]
85    pub concurrency: usize,
86    #[serde(default)]
87    pub strategy: Option<GenerationStrategy>,
88    #[serde(default)]
89    pub strategy_config: HashMap<String, serde_yaml::Value>,
90    #[serde(default)]
91    pub template: Option<String>,
92    #[serde(default)]
93    pub system_prompt: Option<String>,
94    #[serde(default)]
95    pub categories: Option<Vec<String>>,
96}
97
/// Serde default for [`GenerationConfig::count`]: returns `100`.
fn default_count() -> usize {
    100
}
101
/// Serde default for [`GenerationConfig::concurrency`]: returns `5`.
fn default_concurrency() -> usize {
    5
}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
107#[serde(rename_all = "snake_case")]
108pub enum GenerationTask {
109    Generate,
110    Augment,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(rename_all = "snake_case")]
115pub enum GenerationStrategy {
116    Paraphrase,
117    StyleTransfer,
118    BackTranslation,
119    Custom,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct OutputConfig {
124    pub format: OutputFormat,
125    pub path: PathBuf,
126    #[serde(default = "default_batch_size")]
127    pub batch_size: usize,
128}
129
/// Serde default for [`OutputConfig::batch_size`]: returns `100`.
fn default_batch_size() -> usize {
    100
}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135#[serde(rename_all = "lowercase")]
136pub enum OutputFormat {
137    Json,
138    Jsonl,
139    Csv,
140    Parquet,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize, Default)]
144pub struct ValidationConfig {
145    #[serde(default)]
146    pub min_length: Option<usize>,
147    #[serde(default)]
148    pub max_length: Option<usize>,
149    #[serde(default)]
150    pub json: bool,
151    #[serde(default)]
152    pub json_schema: Option<Vec<String>>,
153    #[serde(default)]
154    pub blocklist: bool,
155    #[serde(default)]
156    pub repetition: bool,
157    #[serde(default)]
158    pub dedupe: Option<DedupeStrategy>,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
162#[serde(rename_all = "lowercase")]
163pub enum DedupeStrategy {
164    Exact,
165    Normalized,
166    Jaccard,
167}
168
169impl SynthConfig {
170    pub fn from_yaml(content: &str) -> crate::Result<Self> {
171        serde_yaml::from_str(content).map_err(Into::into)
172    }
173
174    pub fn from_file(path: &PathBuf) -> crate::Result<Self> {
175        let content = std::fs::read_to_string(path)?;
176        Self::from_yaml(&content)
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// An augment-style config with a `huggingface` source parses, and the
    /// fields left out of the YAML take their serde defaults.
    #[test]
    fn test_parse_augment_config() {
        let yaml = r#"
name: "sentiment_augmentation"

source:
  type: huggingface
  dataset: "cornell-movie-review-data/rotten_tomatoes"
  split: "train"
  sample: 1000

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: augment
  count_per_example: 3
  concurrency: 10
  strategy: paraphrase

output:
  format: jsonl
  path: "./output/augmented.jsonl"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "sentiment_augmentation");
        assert!(matches!(
            config.source,
            Some(SourceConfig::HuggingFace { .. })
        ));
        assert!(matches!(config.provider, ProviderConfig::OpenAI { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Augment));

        // Explicitly supplied values.
        assert_eq!(config.generation.count_per_example, Some(3));
        assert_eq!(config.generation.concurrency, 10);
        assert!(matches!(
            config.generation.strategy,
            Some(GenerationStrategy::Paraphrase)
        ));

        // Values omitted from the YAML fall back to their serde defaults.
        assert_eq!(config.generation.count, 100);
        assert_eq!(config.output.batch_size, 100);
        assert!(config.generation.strategy_config.is_empty());
        assert!(config.validation.is_none());
    }

    /// A generate-style config with no `source` section parses, and the
    /// category list and template are picked up.
    #[test]
    fn test_parse_generate_config() {
        let yaml = r#"
name: "product_reviews"

provider:
  type: anthropic
  model: "claude-haiku-4-5-20251001"

generation:
  task: generate
  count: 500
  concurrency: 5
  categories:
    - electronics
    - books
    - clothing
  template: |
    Generate a realistic {category} product review.
    Output only the review text.

output:
  format: parquet
  path: "./output/reviews.parquet"
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.name, "product_reviews");
        assert!(config.source.is_none());
        assert!(matches!(config.provider, ProviderConfig::Anthropic { .. }));
        assert!(matches!(config.generation.task, GenerationTask::Generate));
        assert_eq!(config.generation.categories.as_ref().unwrap().len(), 3);

        assert_eq!(config.generation.count, 500);
        assert_eq!(config.generation.concurrency, 5);
        assert!(config.generation.template.is_some());
        assert!(matches!(config.output.format, OutputFormat::Parquet));
    }

    /// Every field of the `validation` section round-trips from YAML.
    #[test]
    fn test_parse_validation_config() {
        let yaml = r#"
name: "with_validation"

provider:
  type: openai
  model: "gpt-4o-mini"

generation:
  task: generate
  count: 10
  template: "Generate JSON: {\"q\": \"...\", \"a\": \"...\"}"

output:
  format: jsonl
  path: "./output.jsonl"

validation:
  min_length: 20
  max_length: 1000
  json: true
  json_schema:
    - question
    - answer
  blocklist: true
  repetition: true
  dedupe: normalized
"#;

        let config = SynthConfig::from_yaml(yaml).unwrap();
        let v = config.validation.unwrap();
        assert_eq!(v.min_length, Some(20));
        assert_eq!(v.max_length, Some(1000));
        assert!(v.json);
        assert_eq!(v.json_schema.unwrap(), vec!["question", "answer"]);
        assert!(v.blocklist);
        assert!(v.repetition);
        assert!(matches!(v.dedupe, Some(DedupeStrategy::Normalized)));
    }
}